├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── feature_request.md │ └── questions---help---support.md └── workflows │ └── python-publish.yml ├── CITATION.cff ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── datasets └── README.md ├── examples ├── colab_record_example.ipynb ├── cpp │ ├── README.md │ ├── silero-vad-onnx.cpp │ └── wav.h ├── cpp_libtorch │ ├── README.md │ ├── aepyx.wav │ ├── main.cc │ ├── silero │ ├── silero_torch.cc │ ├── silero_torch.h │ └── wav.h ├── csharp │ ├── Program.cs │ ├── SileroSpeechSegment.cs │ ├── SileroVadDetector.cs │ ├── SileroVadOnnxModel.cs │ ├── VadDotNet.csproj │ └── resources │ │ └── put_model_here.txt ├── go │ ├── README.md │ ├── cmd │ │ └── main.go │ ├── go.mod │ └── go.sum ├── haskell │ ├── README.md │ ├── app │ │ └── Main.hs │ ├── example.cabal │ ├── package.yaml │ ├── stack.yaml │ └── stack.yaml.lock ├── java-example │ ├── pom.xml │ └── src │ │ └── main │ │ └── java │ │ └── org │ │ └── example │ │ ├── App.java │ │ ├── SlieroVadDetector.java │ │ └── SlieroVadOnnxModel.java ├── java-wav-file-example │ └── src │ │ └── main │ │ └── java │ │ └── org │ │ └── example │ │ ├── App.java │ │ ├── SileroSpeechSegment.java │ │ ├── SileroVadDetector.java │ │ └── SileroVadOnnxModel.java ├── microphone_and_webRTC_integration │ ├── README.md │ └── microphone_and_webRTC_integration.py ├── parallel_example.ipynb ├── pyaudio-streaming │ ├── README.md │ └── pyaudio-streaming-examples.ipynb └── rust-example │ ├── .gitignore │ ├── Cargo.lock │ ├── Cargo.toml │ ├── README.md │ └── src │ ├── main.rs │ ├── silero.rs │ ├── utils.rs │ └── vad_iter.rs ├── files └── silero_logo.jpg ├── hubconf.py ├── pyproject.toml ├── silero-vad.ipynb ├── src └── silero_vad │ ├── __init__.py │ ├── data │ ├── __init__.py │ ├── silero_vad.jit │ ├── silero_vad.onnx │ ├── silero_vad_16k_op15.onnx │ └── silero_vad_half.onnx │ ├── model.py │ └── utils_vad.py └── tuning ├── README.md ├── __init__.py ├── config.yml ├── example_dataframe.feather ├── search_thresholds.py ├── tune.py └── utils.py /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: Bug report - [X] 5 | labels: bug 6 | assignees: snakers4 7 | 8 | --- 9 | 10 | ## 🐛 Bug 11 | 12 | 13 | 14 | ## To Reproduce 15 | 16 | Steps to reproduce the behavior: 17 | 18 | 1. 19 | 2. 20 | 3. 21 | 22 | 23 | 24 | ## Expected behavior 25 | 26 | 27 | 28 | ## Environment 29 | 30 | Please copy and paste the output from this 31 | [environment collection script](https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py) 32 | (or fill out the checklist below manually). 33 | 34 | You can get the script and run it with: 35 | ``` 36 | wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py 37 | # For security purposes, please check the contents of collect_env.py before running it. 
38 | python collect_env.py 39 | ``` 40 | 41 | - PyTorch Version (e.g., 1.0): 42 | - OS (e.g., Linux): 43 | - How you installed PyTorch (`conda`, `pip`, source): 44 | - Build command you used (if compiling from source): 45 | - Python version: 46 | - CUDA/cuDNN version: 47 | - GPU models and configuration: 48 | - Any other relevant information: 49 | 50 | ## Additional context 51 | 52 | 53 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: Feature request - [X] 5 | labels: enhancement 6 | assignees: snakers4 7 | 8 | --- 9 | 10 | ## 🚀 Feature 11 | 12 | 13 | ## Motivation 14 | 15 | 16 | 17 | ## Pitch 18 | 19 | 20 | 21 | ## Alternatives 22 | 23 | 24 | 25 | ## Additional context 26 | 27 | 28 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/questions---help---support.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Questions / Help / Support 3 | about: Ask for help, support or ask a question 4 | title: "❓ Questions / Help / Support" 5 | labels: help wanted 6 | assignees: snakers4 7 | 8 | --- 9 | 10 | ## ❓ Questions and Help 11 | 12 | We have a [wiki](https://github.com/snakers4/silero-models/wiki) available for our users. Please make sure you have checked it out first. 13 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | push: 13 | tags: 14 | - '*' 15 | 16 | permissions: 17 | contents: read 18 | 19 | jobs: 20 | deploy: 21 | 22 | runs-on: ubuntu-latest 23 | 24 | steps: 25 | - uses: actions/checkout@v4 26 | - name: Set up Python 27 | uses: actions/setup-python@v3 28 | with: 29 | python-version: '3.x' 30 | - name: Install dependencies 31 | run: | 32 | python -m pip install --upgrade pip 33 | pip install build 34 | - name: Build package 35 | run: python -m build 36 | - name: Publish package 37 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 38 | with: 39 | user: __token__ 40 | password: ${{ secrets.PYPI_API_TOKEN }} 41 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 
3 | title: "Silero VAD" 4 | authors: 5 | - family-names: "Silero Team" 6 | email: "hello@silero.ai" 7 | type: software 8 | repository-code: "https://github.com/snakers4/silero-vad" 9 | license: MIT 10 | abstract: "Pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier" 11 | preferred-citation: 12 | type: software 13 | authors: 14 | - family-names: "Silero Team" 15 | email: "hello@silero.ai" 16 | title: "Silero VAD: pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier" 17 | year: 2024 18 | publisher: "GitHub" 19 | journal: "GitHub repository" 20 | howpublished: "https://github.com/snakers4/silero-vad" 21 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. Examples of 50 | representing a project or community include using an official project e-mail 51 | address, posting via an official social media account, or acting as an appointed 52 | representative at an online or offline event. Representation of a project may be 53 | further defined and clarified by project maintainers. 
54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at aveysov@gmail.com. All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020-present Silero Team 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Mailing list : test](http://img.shields.io/badge/Email-gray.svg?style=for-the-badge&logo=gmail)](mailto:hello@silero.ai) [![Mailing list : test](http://img.shields.io/badge/Telegram-blue.svg?style=for-the-badge&logo=telegram)](https://t.me/silero_speech) [![License: CC BY-NC 4.0](https://img.shields.io/badge/License-MIT-lightgrey.svg?style=for-the-badge)](https://github.com/snakers4/silero-vad/blob/master/LICENSE) [![downloads](https://img.shields.io/pypi/dm/silero-vad?style=for-the-badge)](https://pypi.org/project/silero-vad/) 2 | 3 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/snakers4/silero-vad/blob/master/silero-vad.ipynb) 4 | 5 | ![header](https://user-images.githubusercontent.com/12515440/89997349-b3523080-dc94-11ea-9906-ca2e8bc50535.png) 6 | 7 |
8 |

<h1 align="center">Silero VAD</h1>

9 |
10 | 11 | **Silero VAD** - pre-trained enterprise-grade [Voice Activity Detector](https://en.wikipedia.org/wiki/Voice_activity_detection) (also see our [STT models](https://github.com/snakers4/silero-models)). 12 | 13 |
14 | 15 |

16 | 17 |

18 | 19 | 20 |
21 | Real Time Example 22 | 23 | https://user-images.githubusercontent.com/36505480/144874384-95f80f6d-a4f1-42cc-9be7-004c891dd481.mp4 24 | 25 | Please note that the video loads only if you are logged in to your GitHub account. 26 | 27 |
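The curve in the clip is simply the model's per-chunk speech probability over time. Below is a minimal sketch of computing the same values, based on the `audio_forward` helper used in `examples/colab_record_example.ipynb` (the model is loaded the same way as in the Fast start section below):

```python3
import torch

model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
(get_speech_timestamps, _, read_audio, _, _) = utils

wav = read_audio('path_to_audio_file')
# One probability per 512-sample chunk at 16 kHz, as plotted in the video above
speech_probs = model.audio_forward(wav, sr=16000)[0].tolist()
model.reset_states()
```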
28 | 29 |
30 | 31 |

<h2 align="center">Fast start</h2>

32 |
33 | 34 |
35 | Dependencies 36 | 37 | System requirements to run Python examples on `x86-64` systems: 38 | 39 | - `python 3.8+`; 40 | - 1G+ RAM; 41 | - A modern CPU with AVX, AVX2, AVX-512 or AMX instruction sets. 42 | 43 | Dependencies: 44 | 45 | - `torch>=1.12.0`; 46 | - `torchaudio>=0.12.0` (for I/O only); 47 | - `onnxruntime>=1.16.1` (for ONNX model usage). 48 | 49 | Silero VAD uses the torchaudio library for audio I/O (`torchaudio.info`, `torchaudio.load`, and `torchaudio.save`), so a proper audio backend is required: 50 | 51 | - Option №1 - [**FFmpeg**](https://www.ffmpeg.org/) backend. `conda install -c conda-forge 'ffmpeg<7'`; 52 | - Option №2 - [**sox_io**](https://pypi.org/project/sox/) backend. `apt-get install sox`, TorchAudio is tested on libsox 14.4.2; 53 | - Option №3 - [**soundfile**](https://pypi.org/project/soundfile/) backend. `pip install soundfile`. 54 | 55 | If you plan to run the VAD using only `onnxruntime`, it will run on any other system architecture where onnxruntime is [supported](https://onnxruntime.ai/getting-started). In this case please note that: 56 | 57 | - You will have to implement the I/O; 58 | - You will have to adapt the existing wrappers / examples / post-processing for your use-case. 59 | 60 |
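A quick way to check which of these audio backends torchaudio can actually see (a small sanity-check sketch, not part of the package itself):

```python3
import torchaudio

# Should list at least one of the backends above, e.g. ['ffmpeg', 'soundfile']
available = torchaudio.list_audio_backends()
print(available)
assert available, 'No audio backend found - install ffmpeg, sox or soundfile'
```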
61 | 62 | **Using pip**: 63 | `pip install silero-vad` 64 | 65 | ```python3 66 | from silero_vad import load_silero_vad, read_audio, get_speech_timestamps 67 | model = load_silero_vad() 68 | wav = read_audio('path_to_audio_file') 69 | speech_timestamps = get_speech_timestamps( 70 | wav, 71 | model, 72 | return_seconds=True, # Return speech timestamps in seconds (default is samples) 73 | ) 74 | ``` 75 | 76 | **Using torch.hub**: 77 | ```python3 78 | import torch 79 | torch.set_num_threads(1) 80 | 81 | model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad') 82 | (get_speech_timestamps, _, read_audio, _, _) = utils 83 | 84 | wav = read_audio('path_to_audio_file') 85 | speech_timestamps = get_speech_timestamps( 86 | wav, 87 | model, 88 | return_seconds=True, # Return speech timestamps in seconds (default is samples) 89 | ) 90 | ``` 91 | 92 |
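**Streaming / chunk-by-chunk (sketch)**:

The package also ships a `VADIterator` helper for online processing. The snippet below is a minimal sketch under the assumption that `VADIterator` is exported next to the helpers above and is fed 512-sample chunks at 16 kHz (256 at 8 kHz); check the package for the exact signature.

```python3
from silero_vad import load_silero_vad, read_audio, VADIterator

model = load_silero_vad()
vad_iterator = VADIterator(model, sampling_rate=16000)

wav = read_audio('path_to_audio_file')
window_size_samples = 512  # 512-sample chunks at 16000 Hz (use 256 at 8000 Hz)

for i in range(0, len(wav) - window_size_samples + 1, window_size_samples):
    chunk = wav[i: i + window_size_samples]
    speech_dict = vad_iterator(chunk, return_seconds=True)
    if speech_dict:  # a dict is returned only when speech starts or ends
        print(speech_dict)

vad_iterator.reset_states()  # reset internal states between audio files
```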
93 | 94 |

<h2 align="center">Key Features</h2>

95 |
96 | 97 | - **Stellar accuracy** 98 | 99 | Silero VAD has [excellent results](https://github.com/snakers4/silero-vad/wiki/Quality-Metrics#vs-other-available-solutions) on speech detection tasks. 100 | 101 | - **Fast** 102 | 103 | One audio chunk (30+ ms) [takes](https://github.com/snakers4/silero-vad/wiki/Performance-Metrics#silero-vad-performance-metrics) less than **1ms** to be processed on a single CPU thread. Using batching or GPU can also improve performance considerably. Under certain conditions ONNX may even run up to 4-5x faster. 104 | 105 | - **Lightweight** 106 | 107 | The JIT model is around two megabytes in size. 108 | 109 | - **General** 110 | 111 | Silero VAD was trained on huge corpora that include over **6000** languages, and it performs well on audio from different domains with various levels of background noise and quality. 112 | 113 | - **Flexible sampling rate** 114 | 115 | Silero VAD [supports](https://github.com/snakers4/silero-vad/wiki/Quality-Metrics#sample-rate-comparison) **8000 Hz** and **16000 Hz** [sampling rates](https://en.wikipedia.org/wiki/Sampling_(signal_processing)#Sampling_rate). 116 | 117 | - **Highly Portable** 118 | 119 | Silero VAD benefits from the rich ecosystems built around **PyTorch** and **ONNX** and runs everywhere these runtimes are available. 120 | 121 | - **No Strings Attached** 122 | 123 | Published under the permissive MIT license, Silero VAD has zero strings attached: no telemetry, no registration, no built-in expiration, no keys, no vendor lock-in. 124 | 125 |
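To illustrate the ONNX and 8000 Hz points above, here is a short sketch; the `onnx` and `sampling_rate` keyword arguments are assumptions about the package API, so verify them against the installed version:

```python3
from silero_vad import load_silero_vad, read_audio, get_speech_timestamps

# Sketch only: `onnx=True` and `sampling_rate=...` are assumed keyword arguments.
model = load_silero_vad(onnx=True)          # ONNX Runtime backend instead of JIT
wav = read_audio('path_to_audio_file', sampling_rate=8000)
speech_timestamps = get_speech_timestamps(
    wav,
    model,
    sampling_rate=8000,
    return_seconds=True,
)
```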
126 | 127 |

<h2 align="center">Typical Use Cases</h2>

128 |
129 | 130 | - Voice activity detection for IoT / edge / mobile use cases 131 | - Data cleaning and preparation, voice detection in general 132 | - Telephony and call-center automation, voice bots 133 | - Voice interfaces 134 | 135 |
136 |

<h2 align="center">Links</h2>

137 |
138 | 139 | 140 | - [Examples and Dependencies](https://github.com/snakers4/silero-vad/wiki/Examples-and-Dependencies#dependencies) 141 | - [Quality Metrics](https://github.com/snakers4/silero-vad/wiki/Quality-Metrics) 142 | - [Performance Metrics](https://github.com/snakers4/silero-vad/wiki/Performance-Metrics) 143 | - [Versions and Available Models](https://github.com/snakers4/silero-vad/wiki/Version-history-and-Available-Models) 144 | - [Further reading](https://github.com/snakers4/silero-models#further-reading) 145 | - [FAQ](https://github.com/snakers4/silero-vad/wiki/FAQ) 146 | 147 |
148 |

<h2 align="center">Get In Touch</h2>

149 |
150 | 151 | Try our models, create an [issue](https://github.com/snakers4/silero-vad/issues/new), start a [discussion](https://github.com/snakers4/silero-vad/discussions/new), join our telegram [chat](https://t.me/silero_speech), [email](mailto:hello@silero.ai) us, read our [news](https://t.me/silero_news). 152 | 153 | Please see our [wiki](https://github.com/snakers4/silero-models/wiki) for relevant information and [email](mailto:hello@silero.ai) us directly. 154 | 155 | **Citations** 156 | 157 | ``` 158 | @misc{Silero VAD, 159 | author = {Silero Team}, 160 | title = {Silero VAD: pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier}, 161 | year = {2024}, 162 | publisher = {GitHub}, 163 | journal = {GitHub repository}, 164 | howpublished = {\url{https://github.com/snakers4/silero-vad}}, 165 | commit = {insert_some_commit_here}, 166 | email = {hello@silero.ai} 167 | } 168 | ``` 169 | 170 |
171 |

<h2 align="center">Examples and VAD-based Community Apps</h2>

172 |
173 | 174 | - Example of VAD ONNX Runtime model usage in [C++](https://github.com/snakers4/silero-vad/tree/master/examples/cpp) 175 | 176 | - Voice activity detection for the [browser](https://github.com/ricky0123/vad) using ONNX Runtime Web 177 | 178 | - [Rust](https://github.com/snakers4/silero-vad/tree/master/examples/rust-example), [Go](https://github.com/snakers4/silero-vad/tree/master/examples/go), [Java](https://github.com/snakers4/silero-vad/tree/master/examples/java-example), [C++](https://github.com/snakers4/silero-vad/tree/master/examples/cpp), [C#](https://github.com/snakers4/silero-vad/tree/master/examples/csharp) and [other](https://github.com/snakers4/silero-vad/tree/master/examples) community examples 179 | -------------------------------------------------------------------------------- /datasets/README.md: -------------------------------------------------------------------------------- 1 | # Датасет Silero-VAD 2 | 3 | > Датасет создан при поддержке Фонда содействия инновациям в рамках федерального проекта «Искусственный 4 | интеллект» национальной программы «Цифровая экономика Российской Федерации». 5 | 6 | По ссылкам ниже представлены `.feather` файлы, содержащие размеченные с помощью Silero VAD открытые наборы аудиоданных, а также короткое описание каждого набора данных с примерами загрузки. `.feather` файлы можно открыть с помощью библиотеки `pandas`: 7 | ```python3 8 | import pandas as pd 9 | dataframe = pd.read_feather(PATH_TO_FEATHER_FILE) 10 | ``` 11 | 12 | Каждый `.feather` файл с разметкой содержит следующие колонки: 13 | - `speech_timings` - разметка данного аудио. Это список, содержащий словари вида `{'start': START_SECOND, 'end': END_SECOND}`, где `START_SECOND` и `END_SECOND` - время начала и конца речи в секундах. Количество данных словарей равно количеству речевых аудио отрывков, найденных в данном аудио; 14 | - `language` - ISO код языка данного аудио. 15 | 16 | Колонки, содержащие информацию о загрузке аудио файла различаются и описаны для каждого набора данных ниже. 17 | 18 | **Все данные размечены при временной дискретизации в ~30 миллисекунд (`num_samples` - 512)** 19 | 20 | | Название | Число часов | Число языков | Ссылка | Лицензия | md5sum | 21 | |----------------------|-------------|-------------|--------|----------|----------| 22 | | **Bible.is** | 53,138 | 1,596 | [URL](https://live.bible.is/) | [Уникальная](https://live.bible.is/terms) | ea404eeaf2cd283b8223f63002be11f9 | 23 | | **globalrecordings.net** | 9,743 | 6,171[^1] | [URL](https://globalrecordings.net/en) | CC BY-NC-SA 4.0 | 3c5c0f31b0abd9fe94ddbe8b1e2eb326 | 24 | | **VoxLingua107** | 6,628 | 107 | [URL](https://bark.phon.ioc.ee/voxlingua107/) | CC BY 4.0 | 5dfef33b4d091b6d399cfaf3d05f2140 | 25 | | **Common Voice** | 30,329 | 120 | [URL](https://commonvoice.mozilla.org/en/datasets) | CC0 | 5e30a85126adf74a5fd1496e6ac8695d | 26 | | **MLS** | 50,709 | 8 | [URL](https://www.openslr.org/94/) | CC BY 4.0 | a339d0e94bdf41bba3c003756254ac4e | 27 | | **Итого** | **150,547** | **6,171+** | | | | 28 | 29 | ## Bible.is 30 | 31 | [Ссылка на `.feather` файл с разметкой](https://models.silero.ai/vad_datasets/BibleIs.feather) 32 | 33 | - Колонка `audio_link` содержит ссылки на конкретные аудио файлы. 34 | 35 | ## globalrecordings.net 36 | 37 | [Ссылка на `.feather` файл с разметкой](https://models.silero.ai/vad_datasets/globalrecordings.feather) 38 | 39 | - Колонка `folder_link` содержит ссылки на скачивание `.zip` архива для конкретного языка. Внимание! 
Ссылки на архивы дублируются, т.к каждый архив может содержать множество аудио. 40 | - Колонка `audio_path` содержит пути до конкретного аудио после распаковки соответствующего архива из колонки `folder_link` 41 | 42 | ``Количество уникальных ISO кодов данного датасета не совпадает с фактическим количеством представленных языков, т.к некоторые близкие языки могут кодироваться одним и тем же ISO кодом.`` 43 | 44 | ## VoxLingua107 45 | 46 | [Ссылка на `.feather` файл с разметкой](https://models.silero.ai/vad_datasets/VoxLingua107.feather) 47 | 48 | - Колонка `folder_link` содержит ссылки на скачивание `.zip` архива для конкретного языка. Внимание! Ссылки на архивы дублируются, т.к каждый архив может содержать множество аудио. 49 | - Колонка `audio_path` содержит пути до конкретного аудио после распаковки соответствующего архива из колонки `folder_link` 50 | 51 | ## Common Voice 52 | 53 | [Ссылка на `.feather` файл с разметкой](https://models.silero.ai/vad_datasets/common_voice.feather) 54 | 55 | Этот датасет невозможно скачать по статичным ссылкам. Для загрузки необходимо перейти по [ссылке](https://commonvoice.mozilla.org/en/datasets) и, получив доступ в соответствующей форме, скачать архивы для каждого доступного языка. Внимание! Представленная разметка актуальна для версии исходного датасета `Common Voice Corpus 16.1`. 56 | 57 | - Колонка `audio_path` содержит уникальные названия `.mp3` файлов, полученных после скачивания соответствующего датасета. 58 | 59 | ## MLS 60 | 61 | [Ссылка на `.feather` файл с разметкой](https://models.silero.ai/vad_datasets/MLS.feather) 62 | 63 | - Колонка `folder_link` содержит ссылки на скачивание `.zip` архива для конкретного языка. Внимание! Ссылки на архивы дублируются, т.к каждый архив может содержать множество аудио. 64 | - Колонка `audio_path` содержит пути до конкретного аудио после распаковки соответствующего архива из колонки `folder_link` 65 | 66 | ## Лицензия 67 | 68 | Данный датасет распространяется под [лицензией](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en) `CC BY-NC-SA 4.0`. 
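As a usage illustration for the markup files described above, the sketch below aggregates total detected speech per language from one `.feather` file; the column names (`speech_timings`, `language`) follow the description at the top of this README:

```python3
import pandas as pd

df = pd.read_feather(PATH_TO_FEATHER_FILE)

# Sum the durations of all speech segments found in each audio file
df['speech_seconds'] = df['speech_timings'].apply(
    lambda segments: sum(seg['end'] - seg['start'] for seg in segments)
)

# Total hours of detected speech per ISO language code
hours_per_language = df.groupby('language')['speech_seconds'].sum() / 3600
print(hours_per_language.sort_values(ascending=False).head(10))
```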
69 | 70 | ## Цитирование 71 | 72 | ``` 73 | @misc{Silero VAD Dataset, 74 | author = {Silero Team}, 75 | title = {Silero-VAD Dataset: a large public Internet-scale dataset for voice activity detection for 6000+ languages}, 76 | year = {2024}, 77 | publisher = {GitHub}, 78 | journal = {GitHub repository}, 79 | howpublished = {\url{https://github.com/snakers4/silero-vad/datasets/README.md}}, 80 | email = {hello@silero.ai} 81 | } 82 | ``` 83 | 84 | [^1]: ``Количество уникальных ISO кодов данного датасета не совпадает с фактическим количеством представленных языков, т.к некоторые близкие языки могут кодироваться одним и тем же ISO кодом.`` 85 | -------------------------------------------------------------------------------- /examples/colab_record_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "bccAucKjnPHm" 7 | }, 8 | "source": [ 9 | "### Dependencies and inputs" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": { 16 | "id": "cSih95WFmwgi" 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "#!apt install ffmpeg\n", 21 | "!pip -q install pydub\n", 22 | "from google.colab import output\n", 23 | "from base64 import b64decode, b64encode\n", 24 | "from io import BytesIO\n", 25 | "import numpy as np\n", 26 | "from pydub import AudioSegment\n", 27 | "from IPython.display import HTML, display\n", 28 | "import torch\n", 29 | "import matplotlib.pyplot as plt\n", 30 | "import moviepy.editor as mpe\n", 31 | "from matplotlib.animation import FuncAnimation, FFMpegWriter\n", 32 | "import matplotlib\n", 33 | "matplotlib.use('Agg')\n", 34 | "\n", 35 | "torch.set_num_threads(1)\n", 36 | "\n", 37 | "model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n", 38 | " model='silero_vad',\n", 39 | " force_reload=True)\n", 40 | "\n", 41 | "def int2float(audio):\n", 42 | " samples = audio.get_array_of_samples()\n", 43 | " new_sound = audio._spawn(samples)\n", 44 | " arr = np.array(samples).astype(np.float32)\n", 45 | " arr = arr / np.abs(arr).max()\n", 46 | " return arr\n", 47 | "\n", 48 | "AUDIO_HTML = \"\"\"\n", 49 | "\n", 126 | "\"\"\"\n", 127 | "\n", 128 | "def record(sec=10):\n", 129 | " display(HTML(AUDIO_HTML))\n", 130 | " s = output.eval_js(\"data\")\n", 131 | " b = b64decode(s.split(',')[1])\n", 132 | " audio = AudioSegment.from_file(BytesIO(b))\n", 133 | " audio.export('test.mp3', format='mp3')\n", 134 | " audio = audio.set_channels(1)\n", 135 | " audio = audio.set_frame_rate(16000)\n", 136 | " audio_float = int2float(audio)\n", 137 | " audio_tens = torch.tensor(audio_float)\n", 138 | " return audio_tens\n", 139 | "\n", 140 | "def make_animation(probs, audio_duration, interval=40):\n", 141 | " fig = plt.figure(figsize=(16, 9))\n", 142 | " ax = plt.axes(xlim=(0, audio_duration), ylim=(0, 1.02))\n", 143 | " line, = ax.plot([], [], lw=2)\n", 144 | " x = [i / 16000 * 512 for i in range(len(probs))]\n", 145 | " plt.xlabel('Time, seconds', fontsize=16)\n", 146 | " plt.ylabel('Speech Probability', fontsize=16)\n", 147 | "\n", 148 | " def init():\n", 149 | " plt.fill_between(x, probs, color='#064273')\n", 150 | " line.set_data([], [])\n", 151 | " line.set_color('#990000')\n", 152 | " return line,\n", 153 | "\n", 154 | " def animate(i):\n", 155 | " x = i * interval / 1000 - 0.04\n", 156 | " y = np.linspace(0, 1.02, 2)\n", 157 | "\n", 158 | " line.set_data(x, y)\n", 159 | " line.set_color('#990000')\n", 160 | " return line,\n", 161 | " anim 
= FuncAnimation(fig, animate, init_func=init, interval=interval, save_count=int(audio_duration / (interval / 1000)))\n", 162 | "\n", 163 | " f = r\"animation.mp4\"\n", 164 | " writervideo = FFMpegWriter(fps=1000/interval)\n", 165 | " anim.save(f, writer=writervideo)\n", 166 | " plt.close('all')\n", 167 | "\n", 168 | "def combine_audio(vidname, audname, outname, fps=25):\n", 169 | " my_clip = mpe.VideoFileClip(vidname, verbose=False)\n", 170 | " audio_background = mpe.AudioFileClip(audname)\n", 171 | " final_clip = my_clip.set_audio(audio_background)\n", 172 | " final_clip.write_videofile(outname,fps=fps,verbose=False)\n", 173 | "\n", 174 | "def record_make_animation():\n", 175 | " tensor = record()\n", 176 | " print('Calculating probabilities...')\n", 177 | " speech_probs = []\n", 178 | " window_size_samples = 512\n", 179 | " speech_probs = model.audio_forward(tensor, sr=16000)[0].tolist()\n", 180 | " model.reset_states()\n", 181 | " print('Making animation...')\n", 182 | " make_animation(speech_probs, len(tensor) / 16000)\n", 183 | "\n", 184 | " print('Merging your voice with animation...')\n", 185 | " combine_audio('animation.mp4', 'test.mp3', 'merged.mp4')\n", 186 | " print('Done!')\n", 187 | " mp4 = open('merged.mp4','rb').read()\n", 188 | " data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n", 189 | " display(HTML(\"\"\"\n", 190 | " \n", 193 | " \"\"\" % data_url))\n", 194 | "\n", 195 | " return speech_probs" 196 | ] 197 | }, 198 | { 199 | "cell_type": "markdown", 200 | "metadata": { 201 | "id": "IFVs3GvTnpB1" 202 | }, 203 | "source": [ 204 | "## Record example" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": null, 210 | "metadata": { 211 | "id": "5EBjrTwiqAaQ" 212 | }, 213 | "outputs": [], 214 | "source": [ 215 | "speech_probs = record_make_animation()" 216 | ] 217 | } 218 | ], 219 | "metadata": { 220 | "colab": { 221 | "collapsed_sections": [ 222 | "bccAucKjnPHm" 223 | ], 224 | "name": "Untitled2.ipynb", 225 | "provenance": [] 226 | }, 227 | "kernelspec": { 228 | "display_name": "Python 3", 229 | "name": "python3" 230 | }, 231 | "language_info": { 232 | "name": "python" 233 | } 234 | }, 235 | "nbformat": 4, 236 | "nbformat_minor": 0 237 | } 238 | -------------------------------------------------------------------------------- /examples/cpp/README.md: -------------------------------------------------------------------------------- 1 | # Stream example in C++ 2 | 3 | Here's a simple example of the vad model in c++ onnxruntime. 4 | 5 | 6 | 7 | ## Requirements 8 | 9 | Code are tested in the environments bellow, feel free to try others. 10 | 11 | - WSL2 + Debian-bullseye (docker) 12 | - gcc 12.2.0 13 | - onnxruntime-linux-x64-1.12.1 14 | 15 | 16 | 17 | ## Usage 18 | 19 | 1. Install gcc 12.2.0, or just pull the docker image with `docker pull gcc:12.2.0-bullseye` 20 | 21 | 2. Install onnxruntime-linux-x64-1.12.1 22 | 23 | - Download lib onnxruntime: 24 | 25 | `wget https://github.com/microsoft/onnxruntime/releases/download/v1.12.1/onnxruntime-linux-x64-1.12.1.tgz` 26 | 27 | - Unzip. Assume the path is `/root/onnxruntime-linux-x64-1.12.1` 28 | 29 | 3. Modify wav path & Test configs in main function 30 | 31 | `wav::WavReader wav_reader("${path_to_your_wav_file}");` 32 | 33 | test sample rate, frame per ms, threshold... 34 | 35 | 4. 
Build with gcc and run 36 | 37 | ```bash 38 | # Build 39 | g++ silero-vad-onnx.cpp -I /root/onnxruntime-linux-x64-1.12.1/include/ -L /root/onnxruntime-linux-x64-1.12.1/lib/ -lonnxruntime -Wl,-rpath,/root/onnxruntime-linux-x64-1.12.1/lib/ -o test 40 | 41 | # Run 42 | ./test 43 | ``` -------------------------------------------------------------------------------- /examples/cpp/wav.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2016 Personal (Binbin Zhang) 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef FRONTEND_WAV_H_ 16 | #define FRONTEND_WAV_H_ 17 | 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | 25 | #include 26 | 27 | #include 28 | 29 | // #include "utils/log.h" 30 | 31 | namespace wav { 32 | 33 | struct WavHeader { 34 | char riff[4]; // "riff" 35 | unsigned int size; 36 | char wav[4]; // "WAVE" 37 | char fmt[4]; // "fmt " 38 | unsigned int fmt_size; 39 | uint16_t format; 40 | uint16_t channels; 41 | unsigned int sample_rate; 42 | unsigned int bytes_per_second; 43 | uint16_t block_size; 44 | uint16_t bit; 45 | char data[4]; // "data" 46 | unsigned int data_size; 47 | }; 48 | 49 | class WavReader { 50 | public: 51 | WavReader() : data_(nullptr) {} 52 | explicit WavReader(const std::string& filename) { Open(filename); } 53 | 54 | bool Open(const std::string& filename) { 55 | FILE* fp = fopen(filename.c_str(), "rb"); //文件读取 56 | if (NULL == fp) { 57 | std::cout << "Error in read " << filename; 58 | return false; 59 | } 60 | 61 | WavHeader header; 62 | fread(&header, 1, sizeof(header), fp); 63 | if (header.fmt_size < 16) { 64 | printf("WaveData: expect PCM format data " 65 | "to have fmt chunk of at least size 16.\n"); 66 | return false; 67 | } else if (header.fmt_size > 16) { 68 | int offset = 44 - 8 + header.fmt_size - 16; 69 | fseek(fp, offset, SEEK_SET); 70 | fread(header.data, 8, sizeof(char), fp); 71 | } 72 | // check "riff" "WAVE" "fmt " "data" 73 | 74 | // Skip any sub-chunks between "fmt" and "data". Usually there will 75 | // be a single "fact" sub chunk, but on Windows there can also be a 76 | // "list" sub chunk. 77 | while (0 != strncmp(header.data, "data", 4)) { 78 | // We will just ignore the data in these chunks. 
79 | fseek(fp, header.data_size, SEEK_CUR); 80 | // read next sub chunk 81 | fread(header.data, 8, sizeof(char), fp); 82 | } 83 | 84 | if (header.data_size == 0) { 85 | int offset = ftell(fp); 86 | fseek(fp, 0, SEEK_END); 87 | header.data_size = ftell(fp) - offset; 88 | fseek(fp, offset, SEEK_SET); 89 | } 90 | 91 | num_channel_ = header.channels; 92 | sample_rate_ = header.sample_rate; 93 | bits_per_sample_ = header.bit; 94 | int num_data = header.data_size / (bits_per_sample_ / 8); 95 | data_ = new float[num_data]; // Create 1-dim array 96 | num_samples_ = num_data / num_channel_; 97 | 98 | std::cout << "num_channel_ :" << num_channel_ << std::endl; 99 | std::cout << "sample_rate_ :" << sample_rate_ << std::endl; 100 | std::cout << "bits_per_sample_:" << bits_per_sample_ << std::endl; 101 | std::cout << "num_samples :" << num_data << std::endl; 102 | std::cout << "num_data_size :" << header.data_size << std::endl; 103 | 104 | switch (bits_per_sample_) { 105 | case 8: { 106 | char sample; 107 | for (int i = 0; i < num_data; ++i) { 108 | fread(&sample, 1, sizeof(char), fp); 109 | data_[i] = static_cast(sample) / 32768; 110 | } 111 | break; 112 | } 113 | case 16: { 114 | int16_t sample; 115 | for (int i = 0; i < num_data; ++i) { 116 | fread(&sample, 1, sizeof(int16_t), fp); 117 | data_[i] = static_cast(sample) / 32768; 118 | } 119 | break; 120 | } 121 | case 32: 122 | { 123 | if (header.format == 1) //S32 124 | { 125 | int sample; 126 | for (int i = 0; i < num_data; ++i) { 127 | fread(&sample, 1, sizeof(int), fp); 128 | data_[i] = static_cast(sample) / 32768; 129 | } 130 | } 131 | else if (header.format == 3) // IEEE-float 132 | { 133 | float sample; 134 | for (int i = 0; i < num_data; ++i) { 135 | fread(&sample, 1, sizeof(float), fp); 136 | data_[i] = static_cast(sample); 137 | } 138 | } 139 | else { 140 | printf("unsupported quantization bits\n"); 141 | } 142 | break; 143 | } 144 | default: 145 | printf("unsupported quantization bits\n"); 146 | break; 147 | } 148 | 149 | fclose(fp); 150 | return true; 151 | } 152 | 153 | int num_channel() const { return num_channel_; } 154 | int sample_rate() const { return sample_rate_; } 155 | int bits_per_sample() const { return bits_per_sample_; } 156 | int num_samples() const { return num_samples_; } 157 | 158 | ~WavReader() { 159 | delete[] data_; 160 | } 161 | 162 | const float* data() const { return data_; } 163 | 164 | private: 165 | int num_channel_; 166 | int sample_rate_; 167 | int bits_per_sample_; 168 | int num_samples_; // sample points per channel 169 | float* data_; 170 | }; 171 | 172 | class WavWriter { 173 | public: 174 | WavWriter(const float* data, int num_samples, int num_channel, 175 | int sample_rate, int bits_per_sample) 176 | : data_(data), 177 | num_samples_(num_samples), 178 | num_channel_(num_channel), 179 | sample_rate_(sample_rate), 180 | bits_per_sample_(bits_per_sample) {} 181 | 182 | void Write(const std::string& filename) { 183 | FILE* fp = fopen(filename.c_str(), "w"); 184 | // init char 'riff' 'WAVE' 'fmt ' 'data' 185 | WavHeader header; 186 | char wav_header[44] = {0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57, 187 | 0x41, 0x56, 0x45, 0x66, 0x6d, 0x74, 0x20, 0x10, 0x00, 188 | 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 189 | 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 190 | 0x64, 0x61, 0x74, 0x61, 0x00, 0x00, 0x00, 0x00}; 191 | memcpy(&header, wav_header, sizeof(header)); 192 | header.channels = num_channel_; 193 | header.bit = bits_per_sample_; 194 | header.sample_rate = sample_rate_; 195 | 
header.data_size = num_samples_ * num_channel_ * (bits_per_sample_ / 8); 196 | header.size = sizeof(header) - 8 + header.data_size; 197 | header.bytes_per_second = 198 | sample_rate_ * num_channel_ * (bits_per_sample_ / 8); 199 | header.block_size = num_channel_ * (bits_per_sample_ / 8); 200 | 201 | fwrite(&header, 1, sizeof(header), fp); 202 | 203 | for (int i = 0; i < num_samples_; ++i) { 204 | for (int j = 0; j < num_channel_; ++j) { 205 | switch (bits_per_sample_) { 206 | case 8: { 207 | char sample = static_cast(data_[i * num_channel_ + j]); 208 | fwrite(&sample, 1, sizeof(sample), fp); 209 | break; 210 | } 211 | case 16: { 212 | int16_t sample = static_cast(data_[i * num_channel_ + j]); 213 | fwrite(&sample, 1, sizeof(sample), fp); 214 | break; 215 | } 216 | case 32: { 217 | int sample = static_cast(data_[i * num_channel_ + j]); 218 | fwrite(&sample, 1, sizeof(sample), fp); 219 | break; 220 | } 221 | } 222 | } 223 | } 224 | fclose(fp); 225 | } 226 | 227 | private: 228 | const float* data_; 229 | int num_samples_; // total float points in data_ 230 | int num_channel_; 231 | int sample_rate_; 232 | int bits_per_sample_; 233 | }; 234 | 235 | } // namespace wav 236 | 237 | #endif // FRONTEND_WAV_H_ 238 | -------------------------------------------------------------------------------- /examples/cpp_libtorch/README.md: -------------------------------------------------------------------------------- 1 | # Silero-VAD V5 in C++ (based on LibTorch) 2 | 3 | This is the source code for Silero-VAD V5 in C++, utilizing LibTorch. The primary implementation is CPU-based, and you should compare its results with the Python version. Only results at 16kHz have been tested. 4 | 5 | Additionally, batch and CUDA inference options are available if you want to explore further. Note that when using batch inference, the speech probabilities may slightly differ from the standard version, likely due to differences in caching. Unlike individual input processing, batch inference may not use the cache from previous chunks. Despite this, batch inference offers significantly faster processing. For optimal performance, consider adjusting the threshold when using batch inference. 
6 | 7 | ## Requirements 8 | 9 | - GCC 11.4.0 (GCC >= 5.1) 10 | - LibTorch 1.13.0 (other versions are also acceptable) 11 | 12 | ## Download LibTorch 13 | 14 | ```bash 15 | -CPU Version 16 | wget https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-1.13.0%2Bcpu.zip 17 | unzip libtorch-shared-with-deps-1.13.0+cpu.zip' 18 | 19 | -CUDA Version 20 | wget https://download.pytorch.org/libtorch/cu116/libtorch-shared-with-deps-1.13.0%2Bcu116.zip 21 | unzip libtorch-shared-with-deps-1.13.0+cu116.zip 22 | ``` 23 | 24 | ## Compilation 25 | 26 | ```bash 27 | -CPU Version 28 | g++ main.cc silero_torch.cc -I ./libtorch/include/ -I ./libtorch/include/torch/csrc/api/include -L ./libtorch/lib/ -ltorch -ltorch_cpu -lc10 -Wl,-rpath,./libtorch/lib/ -o silero -std=c++14 -D_GLIBCXX_USE_CXX11_ABI=0 29 | 30 | -CUDA Version 31 | g++ main.cc silero_torch.cc -I ./libtorch/include/ -I ./libtorch/include/torch/csrc/api/include -L ./libtorch/lib/ -ltorch -ltorch_cuda -ltorch_cpu -lc10 -Wl,-rpath,./libtorch/lib/ -o silero -std=c++14 -D_GLIBCXX_USE_CXX11_ABI=0 -DUSE_GPU 32 | ``` 33 | 34 | 35 | ## Optional Compilation Flags 36 | -DUSE_BATCH: Enable batch inference 37 | -DUSE_GPU: Use GPU for inference 38 | 39 | ## Run the Program 40 | To run the program, use the following command: 41 | 42 | `./silero aepyx.wav 16000 0.5` 43 | 44 | The sample file aepyx.wav is part of the Voxconverse dataset. 45 | File details: aepyx.wav is a 16kHz, 16-bit audio file. 46 | -------------------------------------------------------------------------------- /examples/cpp_libtorch/aepyx.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snakers4/silero-vad/0dd45f0bcd7271463c234f3bae5ad25181f9df8b/examples/cpp_libtorch/aepyx.wav -------------------------------------------------------------------------------- /examples/cpp_libtorch/main.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include "silero_torch.h" 3 | #include "wav.h" 4 | 5 | int main(int argc, char* argv[]) { 6 | 7 | if(argc != 4){ 8 | std::cerr<<"Usage : "< "< input_wav(wav_reader.num_samples()); 32 | 33 | for (int i = 0; i < wav_reader.num_samples(); i++) 34 | { 35 | input_wav[i] = static_cast(*(wav_reader.data() + i)); 36 | } 37 | 38 | vad.SpeechProbs(input_wav); 39 | 40 | std::vector speeches = vad.GetSpeechTimestamps(); 41 | for(const auto& speech : speeches){ 42 | if(vad.print_as_samples){ 43 | std::cout<<"{'start': "<(speech.start)<<", 'end': "<(speech.end)<<"}"<& input_wav){ 22 | // Set the sample rate (must match the model's expected sample rate) 23 | // Process the waveform in chunks of 512 samples 24 | int num_samples = input_wav.size(); 25 | int num_chunks = num_samples / window_size_samples; 26 | int remainder_samples = num_samples % window_size_samples; 27 | 28 | total_sample_size += num_samples; 29 | 30 | torch::Tensor output; 31 | std::vector chunks; 32 | 33 | for (int i = 0; i < num_chunks; i++) { 34 | 35 | float* chunk_start = input_wav.data() + i *window_size_samples; 36 | torch::Tensor chunk = torch::from_blob(chunk_start, {1,window_size_samples}, torch::kFloat32); 37 | //std::cout<<"chunk size : "<0){//마지막 chunk && 나머지가 존재 42 | int remaining_samples = num_samples - num_chunks * window_size_samples; 43 | //std::cout<<"Remainder size : "< inputs; 69 | inputs.push_back(batched_chunks); // Batch of chunks 70 | inputs.push_back(sample_rate); // Assuming sample_rate is a valid input for the model 71 | 72 | // Run inference on the batch 
73 | torch::NoGradGuard no_grad; 74 | torch::Tensor output = model.forward(inputs).toTensor(); 75 | #ifdef USE_GPU 76 | output = output.to(at::kCPU); // Move the output back to CPU once 77 | #endif 78 | // Collect output probabilities 79 | for (int i = 0; i < chunks.size(); i++) { 80 | float output_f = output[i].item(); 81 | outputs_prob.push_back(output_f); 82 | //std::cout << "Chunk " << i << " prob: " << output_f<< "\n"; 83 | } 84 | #else 85 | 86 | std::vector outputs; 87 | torch::Tensor batched_chunks = torch::stack(chunks); 88 | #ifdef USE_GPU 89 | batched_chunks = batched_chunks.to(at::kCUDA); 90 | #endif 91 | for (int i = 0; i < chunks.size(); i++) { 92 | torch::NoGradGuard no_grad; 93 | std::vector inputs; 94 | inputs.push_back(batched_chunks[i]); 95 | inputs.push_back(sample_rate); 96 | 97 | torch::Tensor output = model.forward(inputs).toTensor(); 98 | outputs.push_back(output); 99 | } 100 | torch::Tensor all_outputs = torch::stack(outputs); 101 | #ifdef USE_GPU 102 | all_outputs = all_outputs.to(at::kCPU); 103 | #endif 104 | for (int i = 0; i < chunks.size(); i++) { 105 | float output_f = all_outputs[i].item(); 106 | outputs_prob.push_back(output_f); 107 | } 108 | 109 | 110 | 111 | #endif 112 | 113 | } 114 | 115 | 116 | } 117 | 118 | 119 | std::vector VadIterator::GetSpeechTimestamps() { 120 | std::vector speeches = DoVad(); 121 | 122 | #ifdef USE_BATCH 123 | //When you use BATCH inference. You would better use 'mergeSpeeches' function to arrage time stamp. 124 | //It could be better get reasonable output because of distorted probs. 125 | duration_merge_samples = sample_rate * max_duration_merge_ms / 1000; 126 | std::vector speeches_merge = mergeSpeeches(speeches, duration_merge_samples); 127 | if(!print_as_samples){ 128 | for (auto& speech : speeches_merge) { //samples to second 129 | speech.start /= sample_rate; 130 | speech.end /= sample_rate; 131 | } 132 | } 133 | 134 | return speeches_merge; 135 | #else 136 | 137 | if(!print_as_samples){ 138 | for (auto& speech : speeches) { //samples to second 139 | speech.start /= sample_rate; 140 | speech.end /= sample_rate; 141 | } 142 | } 143 | 144 | return speeches; 145 | 146 | #endif 147 | 148 | } 149 | void VadIterator::SetVariables(){ 150 | init_engine(window_size_ms); 151 | } 152 | 153 | void VadIterator::init_engine(int window_size_ms) { 154 | min_silence_samples = sample_rate * min_silence_duration_ms / 1000; 155 | speech_pad_samples = sample_rate * speech_pad_ms / 1000; 156 | window_size_samples = sample_rate / 1000 * window_size_ms; 157 | min_speech_samples = sample_rate * min_speech_duration_ms / 1000; 158 | } 159 | 160 | void VadIterator::init_torch_model(const std::string& model_path) { 161 | at::set_num_threads(1); 162 | model = torch::jit::load(model_path); 163 | 164 | #ifdef USE_GPU 165 | if (!torch::cuda::is_available()) { 166 | std::cout<<"CUDA is not available! 
Please check your GPU settings"< VadIterator::DoVad() { 192 | std::vector speeches; 193 | 194 | for (size_t i = 0; i < outputs_prob.size(); ++i) { 195 | float speech_prob = outputs_prob[i]; 196 | //std::cout << speech_prob << std::endl; 197 | //std::cout << "Chunk " << i << " Prob: " << speech_prob << "\n"; 198 | //std::cout << speech_prob << " "; 199 | current_sample += window_size_samples; 200 | 201 | if (speech_prob >= threshold && temp_end != 0) { 202 | temp_end = 0; 203 | } 204 | 205 | if (speech_prob >= threshold && !triggered) { 206 | triggered = true; 207 | SpeechSegment segment; 208 | segment.start = std::max(static_cast(0), current_sample - speech_pad_samples - window_size_samples); 209 | speeches.push_back(segment); 210 | continue; 211 | } 212 | 213 | if (speech_prob < threshold - 0.15f && triggered) { 214 | if (temp_end == 0) { 215 | temp_end = current_sample; 216 | } 217 | 218 | if (current_sample - temp_end < min_silence_samples) { 219 | continue; 220 | } else { 221 | SpeechSegment& segment = speeches.back(); 222 | segment.end = temp_end + speech_pad_samples - window_size_samples; 223 | temp_end = 0; 224 | triggered = false; 225 | } 226 | } 227 | } 228 | 229 | if (triggered) { //만약 낮은 확률을 보이다가 마지막프레임 prbos만 딱 확률이 높게 나오면 위에서 triggerd = true 메핑과 동시에 segment start가 돼서 문제가 될것 같은데? start = end 같은값? 후처리가 있으니 문제가 없으려나? 230 | std::cout<<"when last triggered is keep working until last Probs"<speech_pad_samples) - (speech.start + this->speech_pad_samples) < min_speech_samples); 242 | //min_speech_samples is 4000samples(0.25sec) 243 | //여기서 포인트!! 계산 할때는 start,end sample에'speech_pad_samples' 사이즈를 추가한후 길이를 측정함. 244 | } 245 | ), 246 | speeches.end() 247 | ); 248 | 249 | 250 | //std::cout< VadIterator::mergeSpeeches(const std::vector& speeches, int duration_merge_samples) { 258 | std::vector mergedSpeeches; 259 | 260 | if (speeches.empty()) { 261 | return mergedSpeeches; // 빈 벡터 반환 262 | } 263 | 264 | // 첫 번째 구간으로 초기화 265 | SpeechSegment currentSegment = speeches[0]; 266 | 267 | for (size_t i = 1; i < speeches.size(); ++i) { //첫번째 start,end 정보 건너뛰기. 그래서 i=1부터 268 | // 두 구간의 차이가 threshold(duration_merge_samples)보다 작은 경우, 합침 269 | if (speeches[i].start - currentSegment.end < duration_merge_samples) { 270 | // 현재 구간의 끝점을 업데이트 271 | currentSegment.end = speeches[i].end; 272 | } else { 273 | // 차이가 threshold(duration_merge_samples) 이상이면 현재 구간을 저장하고 새로운 구간 시작 274 | mergedSpeeches.push_back(currentSegment); 275 | currentSegment = speeches[i]; 276 | } 277 | } 278 | 279 | // 마지막 구간 추가 280 | mergedSpeeches.push_back(currentSegment); 281 | 282 | return mergedSpeeches; 283 | } 284 | 285 | } 286 | -------------------------------------------------------------------------------- /examples/cpp_libtorch/silero_torch.h: -------------------------------------------------------------------------------- 1 | //Author : Nathan Lee 2 | //Created On : 2024-11-18 3 | //Description : silero 5.1 system for torch-script(c++). 
4 | //Version : 1.0 5 | 6 | #ifndef SILERO_TORCH_H 7 | #define SILERO_TORCH_H 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | #include 19 | #include 20 | 21 | 22 | namespace silero{ 23 | 24 | struct SpeechSegment{ 25 | int start; 26 | int end; 27 | }; 28 | 29 | class VadIterator{ 30 | public: 31 | 32 | VadIterator(const std::string &model_path, float threshold = 0.5, int sample_rate = 16000, 33 | int window_size_ms = 32, int speech_pad_ms = 30, int min_silence_duration_ms = 100, 34 | int min_speech_duration_ms = 250, int max_duration_merge_ms = 300, bool print_as_samples = false); 35 | ~VadIterator(); 36 | 37 | 38 | void SpeechProbs(std::vector& input_wav); 39 | std::vector GetSpeechTimestamps(); 40 | void SetVariables(); 41 | 42 | float threshold; 43 | int sample_rate; 44 | int window_size_ms; 45 | int min_speech_duration_ms; 46 | int max_duration_merge_ms; 47 | bool print_as_samples; 48 | 49 | private: 50 | torch::jit::script::Module model; 51 | std::vector outputs_prob; 52 | int min_silence_samples; 53 | int min_speech_samples; 54 | int speech_pad_samples; 55 | int window_size_samples; 56 | int duration_merge_samples; 57 | int current_sample = 0; 58 | 59 | int total_sample_size=0; 60 | 61 | int min_silence_duration_ms; 62 | int speech_pad_ms; 63 | bool triggered = false; 64 | int temp_end = 0; 65 | 66 | void init_engine(int window_size_ms); 67 | void init_torch_model(const std::string& model_path); 68 | void reset_states(); 69 | std::vector DoVad(); 70 | std::vector mergeSpeeches(const std::vector& speeches, int duration_merge_samples); 71 | 72 | }; 73 | 74 | } 75 | #endif // SILERO_TORCH_H 76 | -------------------------------------------------------------------------------- /examples/cpp_libtorch/wav.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2016 Personal (Binbin Zhang) 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | 16 | #ifndef FRONTEND_WAV_H_ 17 | #define FRONTEND_WAV_H_ 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | 25 | #include 26 | 27 | // #include "utils/log.h" 28 | 29 | namespace wav { 30 | 31 | struct WavHeader { 32 | char riff[4]; // "riff" 33 | unsigned int size; 34 | char wav[4]; // "WAVE" 35 | char fmt[4]; // "fmt " 36 | unsigned int fmt_size; 37 | uint16_t format; 38 | uint16_t channels; 39 | unsigned int sample_rate; 40 | unsigned int bytes_per_second; 41 | uint16_t block_size; 42 | uint16_t bit; 43 | char data[4]; // "data" 44 | unsigned int data_size; 45 | }; 46 | 47 | class WavReader { 48 | public: 49 | WavReader() : data_(nullptr) {} 50 | explicit WavReader(const std::string& filename) { Open(filename); } 51 | 52 | bool Open(const std::string& filename) { 53 | FILE* fp = fopen(filename.c_str(), "rb"); //文件读取 54 | if (NULL == fp) { 55 | std::cout << "Error in read " << filename; 56 | return false; 57 | } 58 | 59 | WavHeader header; 60 | fread(&header, 1, sizeof(header), fp); 61 | if (header.fmt_size < 16) { 62 | printf("WaveData: expect PCM format data " 63 | "to have fmt chunk of at least size 16.\n"); 64 | return false; 65 | } else if (header.fmt_size > 16) { 66 | int offset = 44 - 8 + header.fmt_size - 16; 67 | fseek(fp, offset, SEEK_SET); 68 | fread(header.data, 8, sizeof(char), fp); 69 | } 70 | // check "riff" "WAVE" "fmt " "data" 71 | 72 | // Skip any sub-chunks between "fmt" and "data". Usually there will 73 | // be a single "fact" sub chunk, but on Windows there can also be a 74 | // "list" sub chunk. 75 | while (0 != strncmp(header.data, "data", 4)) { 76 | // We will just ignore the data in these chunks. 77 | fseek(fp, header.data_size, SEEK_CUR); 78 | // read next sub chunk 79 | fread(header.data, 8, sizeof(char), fp); 80 | } 81 | 82 | if (header.data_size == 0) { 83 | int offset = ftell(fp); 84 | fseek(fp, 0, SEEK_END); 85 | header.data_size = ftell(fp) - offset; 86 | fseek(fp, offset, SEEK_SET); 87 | } 88 | 89 | num_channel_ = header.channels; 90 | sample_rate_ = header.sample_rate; 91 | bits_per_sample_ = header.bit; 92 | int num_data = header.data_size / (bits_per_sample_ / 8); 93 | data_ = new float[num_data]; // Create 1-dim array 94 | num_samples_ = num_data / num_channel_; 95 | 96 | std::cout << "num_channel_ :" << num_channel_ << std::endl; 97 | std::cout << "sample_rate_ :" << sample_rate_ << std::endl; 98 | std::cout << "bits_per_sample_:" << bits_per_sample_ << std::endl; 99 | std::cout << "num_samples :" << num_data << std::endl; 100 | std::cout << "num_data_size :" << header.data_size << std::endl; 101 | 102 | switch (bits_per_sample_) { 103 | case 8: { 104 | char sample; 105 | for (int i = 0; i < num_data; ++i) { 106 | fread(&sample, 1, sizeof(char), fp); 107 | data_[i] = static_cast(sample) / 32768; 108 | } 109 | break; 110 | } 111 | case 16: { 112 | int16_t sample; 113 | for (int i = 0; i < num_data; ++i) { 114 | fread(&sample, 1, sizeof(int16_t), fp); 115 | data_[i] = static_cast(sample) / 32768; 116 | } 117 | break; 118 | } 119 | case 32: 120 | { 121 | if (header.format == 1) //S32 122 | { 123 | int sample; 124 | for (int i = 0; i < num_data; ++i) { 125 | fread(&sample, 1, sizeof(int), fp); 126 | data_[i] = static_cast(sample) / 32768; 127 | } 128 | } 129 | else if (header.format == 3) // IEEE-float 130 | { 131 | float sample; 132 | for (int i = 0; i < num_data; ++i) { 133 | fread(&sample, 1, sizeof(float), fp); 134 | data_[i] = static_cast(sample); 135 | } 136 | } 137 | else { 138 | printf("unsupported 
quantization bits\n"); 139 | } 140 | break; 141 | } 142 | default: 143 | printf("unsupported quantization bits\n"); 144 | break; 145 | } 146 | 147 | fclose(fp); 148 | return true; 149 | } 150 | 151 | int num_channel() const { return num_channel_; } 152 | int sample_rate() const { return sample_rate_; } 153 | int bits_per_sample() const { return bits_per_sample_; } 154 | int num_samples() const { return num_samples_; } 155 | 156 | ~WavReader() { 157 | delete[] data_; 158 | } 159 | 160 | const float* data() const { return data_; } 161 | 162 | private: 163 | int num_channel_; 164 | int sample_rate_; 165 | int bits_per_sample_; 166 | int num_samples_; // sample points per channel 167 | float* data_; 168 | }; 169 | 170 | class WavWriter { 171 | public: 172 | WavWriter(const float* data, int num_samples, int num_channel, 173 | int sample_rate, int bits_per_sample) 174 | : data_(data), 175 | num_samples_(num_samples), 176 | num_channel_(num_channel), 177 | sample_rate_(sample_rate), 178 | bits_per_sample_(bits_per_sample) {} 179 | 180 | void Write(const std::string& filename) { 181 | FILE* fp = fopen(filename.c_str(), "w"); 182 | // init char 'riff' 'WAVE' 'fmt ' 'data' 183 | WavHeader header; 184 | char wav_header[44] = {0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57, 185 | 0x41, 0x56, 0x45, 0x66, 0x6d, 0x74, 0x20, 0x10, 0x00, 186 | 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 187 | 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 188 | 0x64, 0x61, 0x74, 0x61, 0x00, 0x00, 0x00, 0x00}; 189 | memcpy(&header, wav_header, sizeof(header)); 190 | header.channels = num_channel_; 191 | header.bit = bits_per_sample_; 192 | header.sample_rate = sample_rate_; 193 | header.data_size = num_samples_ * num_channel_ * (bits_per_sample_ / 8); 194 | header.size = sizeof(header) - 8 + header.data_size; 195 | header.bytes_per_second = 196 | sample_rate_ * num_channel_ * (bits_per_sample_ / 8); 197 | header.block_size = num_channel_ * (bits_per_sample_ / 8); 198 | 199 | fwrite(&header, 1, sizeof(header), fp); 200 | 201 | for (int i = 0; i < num_samples_; ++i) { 202 | for (int j = 0; j < num_channel_; ++j) { 203 | switch (bits_per_sample_) { 204 | case 8: { 205 | char sample = static_cast(data_[i * num_channel_ + j]); 206 | fwrite(&sample, 1, sizeof(sample), fp); 207 | break; 208 | } 209 | case 16: { 210 | int16_t sample = static_cast(data_[i * num_channel_ + j]); 211 | fwrite(&sample, 1, sizeof(sample), fp); 212 | break; 213 | } 214 | case 32: { 215 | int sample = static_cast(data_[i * num_channel_ + j]); 216 | fwrite(&sample, 1, sizeof(sample), fp); 217 | break; 218 | } 219 | } 220 | } 221 | } 222 | fclose(fp); 223 | } 224 | 225 | private: 226 | const float* data_; 227 | int num_samples_; // total float points in data_ 228 | int num_channel_; 229 | int sample_rate_; 230 | int bits_per_sample_; 231 | }; 232 | 233 | } // namespace wenet 234 | 235 | #endif // FRONTEND_WAV_H_ 236 | -------------------------------------------------------------------------------- /examples/csharp/Program.cs: -------------------------------------------------------------------------------- 1 | using System.Text; 2 | 3 | namespace VadDotNet; 4 | 5 | 6 | class Program 7 | { 8 | private const string MODEL_PATH = "./resources/silero_vad.onnx"; 9 | private const string EXAMPLE_WAV_FILE = "./resources/example.wav"; 10 | private const int SAMPLE_RATE = 16000; 11 | private const float THRESHOLD = 0.5f; 12 | private const int MIN_SPEECH_DURATION_MS = 250; 13 | private const float MAX_SPEECH_DURATION_SECONDS = 
float.PositiveInfinity; 14 | private const int MIN_SILENCE_DURATION_MS = 100; 15 | private const int SPEECH_PAD_MS = 30; 16 | 17 | public static void Main(string[] args) 18 | { 19 | 20 | var vadDetector = new SileroVadDetector(MODEL_PATH, THRESHOLD, SAMPLE_RATE, 21 | MIN_SPEECH_DURATION_MS, MAX_SPEECH_DURATION_SECONDS, MIN_SILENCE_DURATION_MS, SPEECH_PAD_MS); 22 | List speechTimeList = vadDetector.GetSpeechSegmentList(new FileInfo(EXAMPLE_WAV_FILE)); 23 | //Console.WriteLine(speechTimeList.ToJson()); 24 | StringBuilder sb = new StringBuilder(); 25 | foreach (var speechSegment in speechTimeList) 26 | { 27 | sb.Append($"start second: {speechSegment.StartSecond}, end second: {speechSegment.EndSecond}\n"); 28 | 29 | } 30 | Console.WriteLine(sb.ToString()); 31 | 32 | } 33 | 34 | 35 | } 36 | -------------------------------------------------------------------------------- /examples/csharp/SileroSpeechSegment.cs: -------------------------------------------------------------------------------- 1 | namespace VadDotNet; 2 | 3 | public class SileroSpeechSegment 4 | { 5 | public int? StartOffset { get; set; } 6 | public int? EndOffset { get; set; } 7 | public float? StartSecond { get; set; } 8 | public float? EndSecond { get; set; } 9 | 10 | public SileroSpeechSegment() 11 | { 12 | } 13 | 14 | public SileroSpeechSegment(int startOffset, int? endOffset, float? startSecond, float? endSecond) 15 | { 16 | StartOffset = startOffset; 17 | EndOffset = endOffset; 18 | StartSecond = startSecond; 19 | EndSecond = endSecond; 20 | } 21 | } -------------------------------------------------------------------------------- /examples/csharp/SileroVadDetector.cs: -------------------------------------------------------------------------------- 1 | using NAudio.Wave; 2 | using VADdotnet; 3 | 4 | namespace VadDotNet; 5 | 6 | public class SileroVadDetector 7 | { 8 | private readonly SileroVadOnnxModel _model; 9 | private readonly float _threshold; 10 | private readonly float _negThreshold; 11 | private readonly int _samplingRate; 12 | private readonly int _windowSizeSample; 13 | private readonly float _minSpeechSamples; 14 | private readonly float _speechPadSamples; 15 | private readonly float _maxSpeechSamples; 16 | private readonly float _minSilenceSamples; 17 | private readonly float _minSilenceSamplesAtMaxSpeech; 18 | private int _audioLengthSamples; 19 | private const float THRESHOLD_GAP = 0.15f; 20 | // ReSharper disable once InconsistentNaming 21 | private const int SAMPLING_RATE_8K = 8000; 22 | // ReSharper disable once InconsistentNaming 23 | private const int SAMPLING_RATE_16K = 16000; 24 | 25 | public SileroVadDetector(string onnxModelPath, float threshold, int samplingRate, 26 | int minSpeechDurationMs, float maxSpeechDurationSeconds, 27 | int minSilenceDurationMs, int speechPadMs) 28 | { 29 | if (samplingRate != SAMPLING_RATE_8K && samplingRate != SAMPLING_RATE_16K) 30 | { 31 | throw new ArgumentException("Sampling rate not support, only available for [8000, 16000]"); 32 | } 33 | 34 | this._model = new SileroVadOnnxModel(onnxModelPath); 35 | this._samplingRate = samplingRate; 36 | this._threshold = threshold; 37 | this._negThreshold = threshold - THRESHOLD_GAP; 38 | this._windowSizeSample = samplingRate == SAMPLING_RATE_16K ? 
512 : 256; 39 | this._minSpeechSamples = samplingRate * minSpeechDurationMs / 1000f; 40 | this._speechPadSamples = samplingRate * speechPadMs / 1000f; 41 | this._maxSpeechSamples = samplingRate * maxSpeechDurationSeconds - _windowSizeSample - 2 * _speechPadSamples; 42 | this._minSilenceSamples = samplingRate * minSilenceDurationMs / 1000f; 43 | this._minSilenceSamplesAtMaxSpeech = samplingRate * 98 / 1000f; 44 | this.Reset(); 45 | } 46 | 47 | public void Reset() 48 | { 49 | _model.ResetStates(); 50 | } 51 | 52 | public List GetSpeechSegmentList(FileInfo wavFile) 53 | { 54 | Reset(); 55 | 56 | using (var audioFile = new AudioFileReader(wavFile.FullName)) 57 | { 58 | List speechProbList = new List(); 59 | this._audioLengthSamples = (int)(audioFile.Length / 2); 60 | float[] buffer = new float[this._windowSizeSample]; 61 | 62 | while (audioFile.Read(buffer, 0, buffer.Length) > 0) 63 | { 64 | float speechProb = _model.Call(new[] { buffer }, _samplingRate)[0]; 65 | speechProbList.Add(speechProb); 66 | } 67 | 68 | return CalculateProb(speechProbList); 69 | } 70 | } 71 | 72 | private List CalculateProb(List speechProbList) 73 | { 74 | List result = new List(); 75 | bool triggered = false; 76 | int tempEnd = 0, prevEnd = 0, nextStart = 0; 77 | SileroSpeechSegment segment = new SileroSpeechSegment(); 78 | 79 | for (int i = 0; i < speechProbList.Count; i++) 80 | { 81 | float speechProb = speechProbList[i]; 82 | if (speechProb >= _threshold && (tempEnd != 0)) 83 | { 84 | tempEnd = 0; 85 | if (nextStart < prevEnd) 86 | { 87 | nextStart = _windowSizeSample * i; 88 | } 89 | } 90 | 91 | if (speechProb >= _threshold && !triggered) 92 | { 93 | triggered = true; 94 | segment.StartOffset = _windowSizeSample * i; 95 | continue; 96 | } 97 | 98 | if (triggered && (_windowSizeSample * i) - segment.StartOffset > _maxSpeechSamples) 99 | { 100 | if (prevEnd != 0) 101 | { 102 | segment.EndOffset = prevEnd; 103 | result.Add(segment); 104 | segment = new SileroSpeechSegment(); 105 | if (nextStart < prevEnd) 106 | { 107 | triggered = false; 108 | } 109 | else 110 | { 111 | segment.StartOffset = nextStart; 112 | } 113 | 114 | prevEnd = 0; 115 | nextStart = 0; 116 | tempEnd = 0; 117 | } 118 | else 119 | { 120 | segment.EndOffset = _windowSizeSample * i; 121 | result.Add(segment); 122 | segment = new SileroSpeechSegment(); 123 | prevEnd = 0; 124 | nextStart = 0; 125 | tempEnd = 0; 126 | triggered = false; 127 | continue; 128 | } 129 | } 130 | 131 | if (speechProb < _negThreshold && triggered) 132 | { 133 | if (tempEnd == 0) 134 | { 135 | tempEnd = _windowSizeSample * i; 136 | } 137 | 138 | if (((_windowSizeSample * i) - tempEnd) > _minSilenceSamplesAtMaxSpeech) 139 | { 140 | prevEnd = tempEnd; 141 | } 142 | 143 | if ((_windowSizeSample * i) - tempEnd < _minSilenceSamples) 144 | { 145 | continue; 146 | } 147 | else 148 | { 149 | segment.EndOffset = tempEnd; 150 | if ((segment.EndOffset - segment.StartOffset) > _minSpeechSamples) 151 | { 152 | result.Add(segment); 153 | } 154 | 155 | segment = new SileroSpeechSegment(); 156 | prevEnd = 0; 157 | nextStart = 0; 158 | tempEnd = 0; 159 | triggered = false; 160 | continue; 161 | } 162 | } 163 | } 164 | 165 | if (segment.StartOffset != null && (_audioLengthSamples - segment.StartOffset) > _minSpeechSamples) 166 | { 167 | segment.EndOffset = _audioLengthSamples; 168 | result.Add(segment); 169 | } 170 | 171 | for (int i = 0; i < result.Count; i++) 172 | { 173 | SileroSpeechSegment item = result[i]; 174 | if (i == 0) 175 | { 176 | item.StartOffset = (int)Math.Max(0, 
item.StartOffset.Value - _speechPadSamples); 177 | } 178 | 179 | if (i != result.Count - 1) 180 | { 181 | SileroSpeechSegment nextItem = result[i + 1]; 182 | int silenceDuration = nextItem.StartOffset.Value - item.EndOffset.Value; 183 | if (silenceDuration < 2 * _speechPadSamples) 184 | { 185 | item.EndOffset = item.EndOffset + (silenceDuration / 2); 186 | nextItem.StartOffset = Math.Max(0, nextItem.StartOffset.Value - (silenceDuration / 2)); 187 | } 188 | else 189 | { 190 | item.EndOffset = (int)Math.Min(_audioLengthSamples, item.EndOffset.Value + _speechPadSamples); 191 | nextItem.StartOffset = (int)Math.Max(0, nextItem.StartOffset.Value - _speechPadSamples); 192 | } 193 | } 194 | else 195 | { 196 | item.EndOffset = (int)Math.Min(_audioLengthSamples, item.EndOffset.Value + _speechPadSamples); 197 | } 198 | } 199 | 200 | return MergeListAndCalculateSecond(result, _samplingRate); 201 | } 202 | 203 | private List MergeListAndCalculateSecond(List original, int samplingRate) 204 | { 205 | List result = new List(); 206 | if (original == null || original.Count == 0) 207 | { 208 | return result; 209 | } 210 | 211 | int left = original[0].StartOffset.Value; 212 | int right = original[0].EndOffset.Value; 213 | if (original.Count > 1) 214 | { 215 | original.Sort((a, b) => a.StartOffset.Value.CompareTo(b.StartOffset.Value)); 216 | for (int i = 1; i < original.Count; i++) 217 | { 218 | SileroSpeechSegment segment = original[i]; 219 | 220 | if (segment.StartOffset > right) 221 | { 222 | result.Add(new SileroSpeechSegment(left, right, 223 | CalculateSecondByOffset(left, samplingRate), CalculateSecondByOffset(right, samplingRate))); 224 | left = segment.StartOffset.Value; 225 | right = segment.EndOffset.Value; 226 | } 227 | else 228 | { 229 | right = Math.Max(right, segment.EndOffset.Value); 230 | } 231 | } 232 | 233 | result.Add(new SileroSpeechSegment(left, right, 234 | CalculateSecondByOffset(left, samplingRate), CalculateSecondByOffset(right, samplingRate))); 235 | } 236 | else 237 | { 238 | result.Add(new SileroSpeechSegment(left, right, 239 | CalculateSecondByOffset(left, samplingRate), CalculateSecondByOffset(right, samplingRate))); 240 | } 241 | 242 | return result; 243 | } 244 | 245 | private float CalculateSecondByOffset(int offset, int samplingRate) 246 | { 247 | float secondValue = offset * 1.0f / samplingRate; 248 | return (float)Math.Floor(secondValue * 1000.0f) / 1000.0f; 249 | } 250 | } -------------------------------------------------------------------------------- /examples/csharp/SileroVadOnnxModel.cs: -------------------------------------------------------------------------------- 1 | using Microsoft.ML.OnnxRuntime; 2 | using Microsoft.ML.OnnxRuntime.Tensors; 3 | using System; 4 | using System.Collections.Generic; 5 | using System.Linq; 6 | 7 | namespace VADdotnet; 8 | 9 | 10 | public class SileroVadOnnxModel : IDisposable 11 | { 12 | private readonly InferenceSession session; 13 | private float[][][] state; 14 | private float[][] context; 15 | private int lastSr = 0; 16 | private int lastBatchSize = 0; 17 | private static readonly List SAMPLE_RATES = new List { 8000, 16000 }; 18 | 19 | public SileroVadOnnxModel(string modelPath) 20 | { 21 | var sessionOptions = new SessionOptions(); 22 | sessionOptions.InterOpNumThreads = 1; 23 | sessionOptions.IntraOpNumThreads = 1; 24 | sessionOptions.EnableCpuMemArena = true; 25 | 26 | session = new InferenceSession(modelPath, sessionOptions); 27 | ResetStates(); 28 | } 29 | 30 | public void ResetStates() 31 | { 32 | state = new float[2][][]; 33 
| state[0] = new float[1][]; 34 | state[1] = new float[1][]; 35 | state[0][0] = new float[128]; 36 | state[1][0] = new float[128]; 37 | context = Array.Empty(); 38 | lastSr = 0; 39 | lastBatchSize = 0; 40 | } 41 | 42 | public void Dispose() 43 | { 44 | session?.Dispose(); 45 | } 46 | 47 | public class ValidationResult 48 | { 49 | public float[][] X { get; } 50 | public int Sr { get; } 51 | 52 | public ValidationResult(float[][] x, int sr) 53 | { 54 | X = x; 55 | Sr = sr; 56 | } 57 | } 58 | 59 | private ValidationResult ValidateInput(float[][] x, int sr) 60 | { 61 | if (x.Length == 1) 62 | { 63 | x = new float[][] { x[0] }; 64 | } 65 | if (x.Length > 2) 66 | { 67 | throw new ArgumentException($"Incorrect audio data dimension: {x[0].Length}"); 68 | } 69 | 70 | if (sr != 16000 && (sr % 16000 == 0)) 71 | { 72 | int step = sr / 16000; 73 | float[][] reducedX = new float[x.Length][]; 74 | 75 | for (int i = 0; i < x.Length; i++) 76 | { 77 | float[] current = x[i]; 78 | float[] newArr = new float[(current.Length + step - 1) / step]; 79 | 80 | for (int j = 0, index = 0; j < current.Length; j += step, index++) 81 | { 82 | newArr[index] = current[j]; 83 | } 84 | 85 | reducedX[i] = newArr; 86 | } 87 | 88 | x = reducedX; 89 | sr = 16000; 90 | } 91 | 92 | if (!SAMPLE_RATES.Contains(sr)) 93 | { 94 | throw new ArgumentException($"Only supports sample rates {string.Join(", ", SAMPLE_RATES)} (or multiples of 16000)"); 95 | } 96 | 97 | if (((float)sr) / x[0].Length > 31.25) 98 | { 99 | throw new ArgumentException("Input audio is too short"); 100 | } 101 | 102 | return new ValidationResult(x, sr); 103 | } 104 | 105 | private static float[][] Concatenate(float[][] a, float[][] b) 106 | { 107 | if (a.Length != b.Length) 108 | { 109 | throw new ArgumentException("The number of rows in both arrays must be the same."); 110 | } 111 | 112 | int rows = a.Length; 113 | int colsA = a[0].Length; 114 | int colsB = b[0].Length; 115 | float[][] result = new float[rows][]; 116 | 117 | for (int i = 0; i < rows; i++) 118 | { 119 | result[i] = new float[colsA + colsB]; 120 | Array.Copy(a[i], 0, result[i], 0, colsA); 121 | Array.Copy(b[i], 0, result[i], colsA, colsB); 122 | } 123 | 124 | return result; 125 | } 126 | 127 | private static float[][] GetLastColumns(float[][] array, int contextSize) 128 | { 129 | int rows = array.Length; 130 | int cols = array[0].Length; 131 | 132 | if (contextSize > cols) 133 | { 134 | throw new ArgumentException("contextSize cannot be greater than the number of columns in the array."); 135 | } 136 | 137 | float[][] result = new float[rows][]; 138 | 139 | for (int i = 0; i < rows; i++) 140 | { 141 | result[i] = new float[contextSize]; 142 | Array.Copy(array[i], cols - contextSize, result[i], 0, contextSize); 143 | } 144 | 145 | return result; 146 | } 147 | 148 | public float[] Call(float[][] x, int sr) 149 | { 150 | var result = ValidateInput(x, sr); 151 | x = result.X; 152 | sr = result.Sr; 153 | int numberSamples = sr == 16000 ? 512 : 256; 154 | 155 | if (x[0].Length != numberSamples) 156 | { 157 | throw new ArgumentException($"Provided number of samples is {x[0].Length} (Supported values: 256 for 8000 sample rate, 512 for 16000)"); 158 | } 159 | 160 | int batchSize = x.Length; 161 | int contextSize = sr == 16000 ? 
64 : 32; 162 | 163 | if (lastBatchSize == 0) 164 | { 165 | ResetStates(); 166 | } 167 | if (lastSr != 0 && lastSr != sr) 168 | { 169 | ResetStates(); 170 | } 171 | if (lastBatchSize != 0 && lastBatchSize != batchSize) 172 | { 173 | ResetStates(); 174 | } 175 | 176 | if (context.Length == 0) 177 | { 178 | context = new float[batchSize][]; 179 | for (int i = 0; i < batchSize; i++) 180 | { 181 | context[i] = new float[contextSize]; 182 | } 183 | } 184 | 185 | x = Concatenate(context, x); 186 | 187 | var inputs = new List 188 | { 189 | NamedOnnxValue.CreateFromTensor("input", new DenseTensor(x.SelectMany(a => a).ToArray(), new[] { x.Length, x[0].Length })), 190 | NamedOnnxValue.CreateFromTensor("sr", new DenseTensor(new[] { (long)sr }, new[] { 1 })), 191 | NamedOnnxValue.CreateFromTensor("state", new DenseTensor(state.SelectMany(a => a.SelectMany(b => b)).ToArray(), new[] { state.Length, state[0].Length, state[0][0].Length })) 192 | }; 193 | 194 | using (var outputs = session.Run(inputs)) 195 | { 196 | var output = outputs.First(o => o.Name == "output").AsTensor(); 197 | var newState = outputs.First(o => o.Name == "stateN").AsTensor(); 198 | 199 | context = GetLastColumns(x, contextSize); 200 | lastSr = sr; 201 | lastBatchSize = batchSize; 202 | 203 | state = new float[newState.Dimensions[0]][][]; 204 | for (int i = 0; i < newState.Dimensions[0]; i++) 205 | { 206 | state[i] = new float[newState.Dimensions[1]][]; 207 | for (int j = 0; j < newState.Dimensions[1]; j++) 208 | { 209 | state[i][j] = new float[newState.Dimensions[2]]; 210 | for (int k = 0; k < newState.Dimensions[2]; k++) 211 | { 212 | state[i][j][k] = newState[i, j, k]; 213 | } 214 | } 215 | } 216 | 217 | return output.ToArray(); 218 | } 219 | } 220 | } 221 | -------------------------------------------------------------------------------- /examples/csharp/VadDotNet.csproj: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | Exe 5 | net8.0 6 | enable 7 | enable 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | PreserveNewest 22 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /examples/csharp/resources/put_model_here.txt: -------------------------------------------------------------------------------- 1 | place onnx model file and example.wav file in this folder 2 | -------------------------------------------------------------------------------- /examples/go/README.md: -------------------------------------------------------------------------------- 1 | ## Golang Example 2 | 3 | This is a sample program of how to run speech detection using `silero-vad` from Golang (CGO + ONNX Runtime). 4 | 5 | ### Requirements 6 | 7 | - Golang >= v1.21 8 | - ONNX Runtime 9 | 10 | ### Usage 11 | 12 | ```sh 13 | go run ./cmd/main.go test.wav 14 | ``` 15 | 16 | > **_Note_** 17 | > 18 | > Make sure you have the ONNX Runtime library and C headers installed in your path. 
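
If the runtime is installed outside the default search paths, a sketch like the following may help CGO and the dynamic loader find it. The install prefix below is an illustrative assumption (adjust to wherever your headers and `libonnxruntime` actually live), and the exact variables needed can depend on how `silero-vad-go` links against ONNX Runtime on your system:

```sh
# Assumed install prefix; replace with the directory that actually contains
# include/onnxruntime_c_api.h and lib/libonnxruntime.so
export ONNXRUNTIME_DIR=/opt/onnxruntime
export CGO_CFLAGS="-I${ONNXRUNTIME_DIR}/include"
export CGO_LDFLAGS="-L${ONNXRUNTIME_DIR}/lib -lonnxruntime"
export LD_LIBRARY_PATH="${ONNXRUNTIME_DIR}/lib:${LD_LIBRARY_PATH}"

go run ./cmd/main.go test.wav
```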
19 | 20 | -------------------------------------------------------------------------------- /examples/go/cmd/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "log" 5 | "os" 6 | 7 | "github.com/streamer45/silero-vad-go/speech" 8 | 9 | "github.com/go-audio/wav" 10 | ) 11 | 12 | func main() { 13 | sd, err := speech.NewDetector(speech.DetectorConfig{ 14 | ModelPath: "../../src/silero_vad/data/silero_vad.onnx", 15 | SampleRate: 16000, 16 | Threshold: 0.5, 17 | MinSilenceDurationMs: 100, 18 | SpeechPadMs: 30, 19 | }) 20 | if err != nil { 21 | log.Fatalf("failed to create speech detector: %s", err) 22 | } 23 | 24 | if len(os.Args) != 2 { 25 | log.Fatalf("invalid arguments provided: expecting one file path") 26 | } 27 | 28 | f, err := os.Open(os.Args[1]) 29 | if err != nil { 30 | log.Fatalf("failed to open sample audio file: %s", err) 31 | } 32 | defer f.Close() 33 | 34 | dec := wav.NewDecoder(f) 35 | 36 | if ok := dec.IsValidFile(); !ok { 37 | log.Fatalf("invalid WAV file") 38 | } 39 | 40 | buf, err := dec.FullPCMBuffer() 41 | if err != nil { 42 | log.Fatalf("failed to get PCM buffer") 43 | } 44 | 45 | pcmBuf := buf.AsFloat32Buffer() 46 | 47 | segments, err := sd.Detect(pcmBuf.Data) 48 | if err != nil { 49 | log.Fatalf("Detect failed: %s", err) 50 | } 51 | 52 | for _, s := range segments { 53 | log.Printf("speech starts at %0.2fs", s.SpeechStartAt) 54 | if s.SpeechEndAt > 0 { 55 | log.Printf("speech ends at %0.2fs", s.SpeechEndAt) 56 | } 57 | } 58 | 59 | err = sd.Destroy() 60 | if err != nil { 61 | log.Fatalf("failed to destroy detector: %s", err) 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /examples/go/go.mod: -------------------------------------------------------------------------------- 1 | module silero 2 | 3 | go 1.21.4 4 | 5 | require ( 6 | github.com/go-audio/wav v1.1.0 7 | github.com/streamer45/silero-vad-go v0.2.1 8 | ) 9 | 10 | require ( 11 | github.com/go-audio/audio v1.0.0 // indirect 12 | github.com/go-audio/riff v1.0.0 // indirect 13 | ) 14 | -------------------------------------------------------------------------------- /examples/go/go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 2 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4= 4 | github.com/go-audio/audio v1.0.0/go.mod h1:6uAu0+H2lHkwdGsAY+j2wHPNPpPoeg5AaEFh9FlA+Zs= 5 | github.com/go-audio/riff v1.0.0 h1:d8iCGbDvox9BfLagY94fBynxSPHO80LmZCaOsmKxokA= 6 | github.com/go-audio/riff v1.0.0/go.mod h1:l3cQwc85y79NQFCRB7TiPoNiaijp6q8Z0Uv38rVG498= 7 | github.com/go-audio/wav v1.1.0 h1:jQgLtbqBzY7G+BM8fXF7AHUk1uHUviWS4X39d5rsL2g= 8 | github.com/go-audio/wav v1.1.0/go.mod h1:mpe9qfwbScEbkd8uybLuIpTgHyrISw/OTuvjUW2iGtE= 9 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 10 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 11 | github.com/streamer45/silero-vad-go v0.2.0 h1:bbRTa6cQuc7VI88y0qicx375UyWoxE6wlVOF+mUg0+g= 12 | github.com/streamer45/silero-vad-go v0.2.0/go.mod h1:B+2FXs/5fZ6pzl6unUZYhZqkYdOB+3saBVzjOzdZnUs= 13 | github.com/streamer45/silero-vad-go v0.2.1 h1:Li1/tTC4H/3cyw6q4weX+U8GWwEL3lTekK/nYa1Cvuk= 14 | github.com/streamer45/silero-vad-go v0.2.1/go.mod 
h1:B+2FXs/5fZ6pzl6unUZYhZqkYdOB+3saBVzjOzdZnUs= 15 | github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= 16 | github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= 17 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 18 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 19 | -------------------------------------------------------------------------------- /examples/haskell/README.md: -------------------------------------------------------------------------------- 1 | # Haskell example 2 | 3 | To run the example, make sure you put an ``example.wav`` in this directory, and then run the following: 4 | ```bash 5 | stack run 6 | ``` 7 | 8 | The ``example.wav`` file must have the following requirements: 9 | - Must be 16khz sample rate. 10 | - Must be mono channel. 11 | - Must be 16-bit audio. 12 | 13 | This uses the [silero-vad](https://hackage.haskell.org/package/silero-vad) package, a haskell implementation based on the C# example. -------------------------------------------------------------------------------- /examples/haskell/app/Main.hs: -------------------------------------------------------------------------------- 1 | module Main (main) where 2 | 3 | import qualified Data.Vector.Storable as Vector 4 | import Data.WAVE 5 | import Data.Function 6 | import Silero 7 | 8 | main :: IO () 9 | main = 10 | withModel $ \model -> do 11 | wav <- getWAVEFile "example.wav" 12 | let samples = 13 | concat (waveSamples wav) 14 | & Vector.fromList 15 | & Vector.map (realToFrac . sampleToDouble) 16 | let vad = 17 | (defaultVad model) 18 | { startThreshold = 0.5 19 | , endThreshold = 0.35 20 | } 21 | segments <- detectSegments vad samples 22 | print segments -------------------------------------------------------------------------------- /examples/haskell/example.cabal: -------------------------------------------------------------------------------- 1 | cabal-version: 1.12 2 | 3 | -- This file has been generated from package.yaml by hpack version 0.37.0. 
4 | -- 5 | -- see: https://github.com/sol/hpack 6 | 7 | name: example 8 | version: 0.1.0.0 9 | build-type: Simple 10 | 11 | executable example-exe 12 | main-is: Main.hs 13 | other-modules: 14 | Paths_example 15 | hs-source-dirs: 16 | app 17 | ghc-options: -Wall -Wcompat -Widentities -Wincomplete-record-updates -Wincomplete-uni-patterns -Wmissing-export-lists -Wmissing-home-modules -Wpartial-fields -Wredundant-constraints -threaded -rtsopts -with-rtsopts=-N 18 | build-depends: 19 | WAVE 20 | , base >=4.7 && <5 21 | , silero-vad 22 | , vector 23 | default-language: Haskell2010 24 | -------------------------------------------------------------------------------- /examples/haskell/package.yaml: -------------------------------------------------------------------------------- 1 | name: example 2 | version: 0.1.0.0 3 | 4 | dependencies: 5 | - base >= 4.7 && < 5 6 | - silero-vad 7 | - WAVE 8 | - vector 9 | 10 | ghc-options: 11 | - -Wall 12 | - -Wcompat 13 | - -Widentities 14 | - -Wincomplete-record-updates 15 | - -Wincomplete-uni-patterns 16 | - -Wmissing-export-lists 17 | - -Wmissing-home-modules 18 | - -Wpartial-fields 19 | - -Wredundant-constraints 20 | 21 | executables: 22 | example-exe: 23 | main: Main.hs 24 | source-dirs: app 25 | ghc-options: 26 | - -threaded 27 | - -rtsopts 28 | - -with-rtsopts=-N -------------------------------------------------------------------------------- /examples/haskell/stack.yaml: -------------------------------------------------------------------------------- 1 | snapshot: 2 | url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/lts/20/26.yaml 3 | 4 | packages: 5 | - . 6 | 7 | extra-deps: 8 | - silero-vad-0.1.0.4@sha256:2bff95be978a2782915b250edc795760d4cf76838e37bb7d4a965dc32566eb0f,5476 9 | - WAVE-0.1.6@sha256:f744ff68f5e3a0d1f84fab373ea35970659085d213aef20860357512d0458c5c,1016 10 | - derive-storable-0.3.1.0@sha256:bd1c51c155a00e2be18325d553d6764dd678904a85647d6ba952af998e70aa59,2313 11 | - vector-0.13.2.0@sha256:98f5cb3080a3487527476e3c272dcadaba1376539f2aa0646f2f19b3af6b2f67,8481 -------------------------------------------------------------------------------- /examples/haskell/stack.yaml.lock: -------------------------------------------------------------------------------- 1 | # This file was autogenerated by Stack. 2 | # You should not edit this file by hand. 
3 | # For more information, please see the documentation at: 4 | # https://docs.haskellstack.org/en/stable/lock_files 5 | 6 | packages: 7 | - completed: 8 | hackage: silero-vad-0.1.0.4@sha256:2bff95be978a2782915b250edc795760d4cf76838e37bb7d4a965dc32566eb0f,5476 9 | pantry-tree: 10 | sha256: a62e813f978d32c87769796fded981d25fcf2875bb2afdf60ed6279f931ccd7f 11 | size: 1391 12 | original: 13 | hackage: silero-vad-0.1.0.4@sha256:2bff95be978a2782915b250edc795760d4cf76838e37bb7d4a965dc32566eb0f,5476 14 | - completed: 15 | hackage: WAVE-0.1.6@sha256:f744ff68f5e3a0d1f84fab373ea35970659085d213aef20860357512d0458c5c,1016 16 | pantry-tree: 17 | sha256: ee5ccd70fa7fe6ffc360ebd762b2e3f44ae10406aa27f3842d55b8cbd1a19498 18 | size: 405 19 | original: 20 | hackage: WAVE-0.1.6@sha256:f744ff68f5e3a0d1f84fab373ea35970659085d213aef20860357512d0458c5c,1016 21 | - completed: 22 | hackage: derive-storable-0.3.1.0@sha256:bd1c51c155a00e2be18325d553d6764dd678904a85647d6ba952af998e70aa59,2313 23 | pantry-tree: 24 | sha256: 48e35a72d1bb593173890616c8d7efd636a650a306a50bb3e1513e679939d27e 25 | size: 902 26 | original: 27 | hackage: derive-storable-0.3.1.0@sha256:bd1c51c155a00e2be18325d553d6764dd678904a85647d6ba952af998e70aa59,2313 28 | - completed: 29 | hackage: vector-0.13.2.0@sha256:98f5cb3080a3487527476e3c272dcadaba1376539f2aa0646f2f19b3af6b2f67,8481 30 | pantry-tree: 31 | sha256: 2176fd677a02a4c47337f7dca5aeca2745dbb821a6ea5c7099b3a991ecd7f4f0 32 | size: 4478 33 | original: 34 | hackage: vector-0.13.2.0@sha256:98f5cb3080a3487527476e3c272dcadaba1376539f2aa0646f2f19b3af6b2f67,8481 35 | snapshots: 36 | - completed: 37 | sha256: 5a59b2a405b3aba3c00188453be172b85893cab8ebc352b1ef58b0eae5d248a2 38 | size: 650475 39 | url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/lts/20/26.yaml 40 | original: 41 | url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/lts/20/26.yaml 42 | -------------------------------------------------------------------------------- /examples/java-example/pom.xml: -------------------------------------------------------------------------------- 1 | 3 | 4.0.0 4 | 5 | org.example 6 | java-example 7 | 1.0-SNAPSHOT 8 | jar 9 | 10 | sliero-vad-example 11 | http://maven.apache.org 12 | 13 | 14 | UTF-8 15 | 16 | 17 | 18 | 19 | junit 20 | junit 21 | 3.8.1 22 | test 23 | 24 | 25 | com.microsoft.onnxruntime 26 | onnxruntime 27 | 1.16.0-rc1 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /examples/java-example/src/main/java/org/example/App.java: -------------------------------------------------------------------------------- 1 | package org.example; 2 | 3 | import ai.onnxruntime.OrtException; 4 | import javax.sound.sampled.*; 5 | import java.util.Map; 6 | 7 | public class App { 8 | 9 | private static final String MODEL_PATH = "src/main/resources/silero_vad.onnx"; 10 | private static final int SAMPLE_RATE = 16000; 11 | private static final float START_THRESHOLD = 0.6f; 12 | private static final float END_THRESHOLD = 0.45f; 13 | private static final int MIN_SILENCE_DURATION_MS = 600; 14 | private static final int SPEECH_PAD_MS = 500; 15 | private static final int WINDOW_SIZE_SAMPLES = 2048; 16 | 17 | public static void main(String[] args) { 18 | // Initialize the Voice Activity Detector 19 | SlieroVadDetector vadDetector; 20 | try { 21 | vadDetector = new SlieroVadDetector(MODEL_PATH, START_THRESHOLD, END_THRESHOLD, SAMPLE_RATE, MIN_SILENCE_DURATION_MS, SPEECH_PAD_MS); 22 | } catch (OrtException e) { 23 
| System.err.println("Error initializing the VAD detector: " + e.getMessage()); 24 | return; 25 | } 26 | 27 | // Set audio format 28 | AudioFormat format = new AudioFormat(SAMPLE_RATE, 16, 1, true, false); 29 | DataLine.Info info = new DataLine.Info(TargetDataLine.class, format); 30 | 31 | // Get the target data line and open it with the specified format 32 | TargetDataLine targetDataLine; 33 | try { 34 | targetDataLine = (TargetDataLine) AudioSystem.getLine(info); 35 | targetDataLine.open(format); 36 | targetDataLine.start(); 37 | } catch (LineUnavailableException e) { 38 | System.err.println("Error opening target data line: " + e.getMessage()); 39 | return; 40 | } 41 | 42 | // Main loop to continuously read data and apply Voice Activity Detection 43 | while (targetDataLine.isOpen()) { 44 | byte[] data = new byte[WINDOW_SIZE_SAMPLES]; 45 | 46 | int numBytesRead = targetDataLine.read(data, 0, data.length); 47 | if (numBytesRead <= 0) { 48 | System.err.println("Error reading data from target data line."); 49 | continue; 50 | } 51 | 52 | // Apply the Voice Activity Detector to the data and get the result 53 | Map detectResult; 54 | try { 55 | detectResult = vadDetector.apply(data, true); 56 | } catch (Exception e) { 57 | System.err.println("Error applying VAD detector: " + e.getMessage()); 58 | continue; 59 | } 60 | 61 | if (!detectResult.isEmpty()) { 62 | System.out.println(detectResult); 63 | } 64 | } 65 | 66 | // Close the target data line to release audio resources 67 | targetDataLine.close(); 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /examples/java-example/src/main/java/org/example/SlieroVadDetector.java: -------------------------------------------------------------------------------- 1 | package org.example; 2 | 3 | import ai.onnxruntime.OrtException; 4 | 5 | import java.math.BigDecimal; 6 | import java.math.RoundingMode; 7 | import java.util.Collections; 8 | import java.util.HashMap; 9 | import java.util.Map; 10 | 11 | 12 | public class SlieroVadDetector { 13 | // OnnxModel model used for speech processing 14 | private final SlieroVadOnnxModel model; 15 | // Threshold for speech start 16 | private final float startThreshold; 17 | // Threshold for speech end 18 | private final float endThreshold; 19 | // Sampling rate 20 | private final int samplingRate; 21 | // Minimum number of silence samples to determine the end threshold of speech 22 | private final float minSilenceSamples; 23 | // Additional number of samples for speech start or end to calculate speech start or end time 24 | private final float speechPadSamples; 25 | // Whether in the triggered state (i.e. 
whether speech is being detected) 26 | private boolean triggered; 27 | // Temporarily stored number of speech end samples 28 | private int tempEnd; 29 | // Number of samples currently being processed 30 | private int currentSample; 31 | 32 | 33 | public SlieroVadDetector(String modelPath, 34 | float startThreshold, 35 | float endThreshold, 36 | int samplingRate, 37 | int minSilenceDurationMs, 38 | int speechPadMs) throws OrtException { 39 | // Check if the sampling rate is 8000 or 16000, if not, throw an exception 40 | if (samplingRate != 8000 && samplingRate != 16000) { 41 | throw new IllegalArgumentException("does not support sampling rates other than [8000, 16000]"); 42 | } 43 | 44 | // Initialize the parameters 45 | this.model = new SlieroVadOnnxModel(modelPath); 46 | this.startThreshold = startThreshold; 47 | this.endThreshold = endThreshold; 48 | this.samplingRate = samplingRate; 49 | this.minSilenceSamples = samplingRate * minSilenceDurationMs / 1000f; 50 | this.speechPadSamples = samplingRate * speechPadMs / 1000f; 51 | // Reset the state 52 | reset(); 53 | } 54 | 55 | // Method to reset the state, including the model state, trigger state, temporary end time, and current sample count 56 | public void reset() { 57 | model.resetStates(); 58 | triggered = false; 59 | tempEnd = 0; 60 | currentSample = 0; 61 | } 62 | 63 | // apply method for processing the audio array, returning possible speech start or end times 64 | public Map apply(byte[] data, boolean returnSeconds) { 65 | 66 | // Convert the byte array to a float array 67 | float[] audioData = new float[data.length / 2]; 68 | for (int i = 0; i < audioData.length; i++) { 69 | audioData[i] = ((data[i * 2] & 0xff) | (data[i * 2 + 1] << 8)) / 32767.0f; 70 | } 71 | 72 | // Get the length of the audio array as the window size 73 | int windowSizeSamples = audioData.length; 74 | // Update the current sample count 75 | currentSample += windowSizeSamples; 76 | 77 | // Call the model to get the prediction probability of speech 78 | float speechProb = 0; 79 | try { 80 | speechProb = model.call(new float[][]{audioData}, samplingRate)[0]; 81 | } catch (OrtException e) { 82 | throw new RuntimeException(e); 83 | } 84 | 85 | // If the speech probability is greater than the threshold and the temporary end time is not 0, reset the temporary end time 86 | // This indicates that the speech duration has exceeded expectations and needs to recalculate the end time 87 | if (speechProb >= startThreshold && tempEnd != 0) { 88 | tempEnd = 0; 89 | } 90 | 91 | // If the speech probability is greater than the threshold and not in the triggered state, set to triggered state and calculate the speech start time 92 | if (speechProb >= startThreshold && !triggered) { 93 | triggered = true; 94 | int speechStart = (int) (currentSample - speechPadSamples); 95 | speechStart = Math.max(speechStart, 0); 96 | Map result = new HashMap<>(); 97 | // Decide whether to return the result in seconds or sample count based on the returnSeconds parameter 98 | if (returnSeconds) { 99 | double speechStartSeconds = speechStart / (double) samplingRate; 100 | double roundedSpeechStart = BigDecimal.valueOf(speechStartSeconds).setScale(1, RoundingMode.HALF_UP).doubleValue(); 101 | result.put("start", roundedSpeechStart); 102 | } else { 103 | result.put("start", (double) speechStart); 104 | } 105 | 106 | return result; 107 | } 108 | 109 | // If the speech probability is less than a certain threshold and in the triggered state, calculate the speech end time 110 | if (speechProb < endThreshold 
&& triggered) { 111 | // Initialize or update the temporary end time 112 | if (tempEnd == 0) { 113 | tempEnd = currentSample; 114 | } 115 | // If the number of silence samples between the current sample and the temporary end time is less than the minimum silence samples, return null 116 | // This indicates that it is not yet possible to determine whether the speech has ended 117 | if (currentSample - tempEnd < minSilenceSamples) { 118 | return Collections.emptyMap(); 119 | } else { 120 | // Calculate the speech end time, reset the trigger state and temporary end time 121 | int speechEnd = (int) (tempEnd + speechPadSamples); 122 | tempEnd = 0; 123 | triggered = false; 124 | Map result = new HashMap<>(); 125 | 126 | if (returnSeconds) { 127 | double speechEndSeconds = speechEnd / (double) samplingRate; 128 | double roundedSpeechEnd = BigDecimal.valueOf(speechEndSeconds).setScale(1, RoundingMode.HALF_UP).doubleValue(); 129 | result.put("end", roundedSpeechEnd); 130 | } else { 131 | result.put("end", (double) speechEnd); 132 | } 133 | return result; 134 | } 135 | } 136 | 137 | // If the above conditions are not met, return null by default 138 | return Collections.emptyMap(); 139 | } 140 | 141 | public void close() throws OrtException { 142 | reset(); 143 | model.close(); 144 | } 145 | } 146 | -------------------------------------------------------------------------------- /examples/java-example/src/main/java/org/example/SlieroVadOnnxModel.java: -------------------------------------------------------------------------------- 1 | package org.example; 2 | 3 | import ai.onnxruntime.OnnxTensor; 4 | import ai.onnxruntime.OrtEnvironment; 5 | import ai.onnxruntime.OrtException; 6 | import ai.onnxruntime.OrtSession; 7 | import java.util.Arrays; 8 | import java.util.HashMap; 9 | import java.util.List; 10 | import java.util.Map; 11 | 12 | public class SlieroVadOnnxModel { 13 | // Define private variable OrtSession 14 | private final OrtSession session; 15 | private float[][][] h; 16 | private float[][][] c; 17 | // Define the last sample rate 18 | private int lastSr = 0; 19 | // Define the last batch size 20 | private int lastBatchSize = 0; 21 | // Define a list of supported sample rates 22 | private static final List SAMPLE_RATES = Arrays.asList(8000, 16000); 23 | 24 | // Constructor 25 | public SlieroVadOnnxModel(String modelPath) throws OrtException { 26 | // Get the ONNX runtime environment 27 | OrtEnvironment env = OrtEnvironment.getEnvironment(); 28 | // Create an ONNX session options object 29 | OrtSession.SessionOptions opts = new OrtSession.SessionOptions(); 30 | // Set the InterOp thread count to 1, InterOp threads are used for parallel processing of different computation graph operations 31 | opts.setInterOpNumThreads(1); 32 | // Set the IntraOp thread count to 1, IntraOp threads are used for parallel processing within a single operation 33 | opts.setIntraOpNumThreads(1); 34 | // Add a CPU device, setting to false disables CPU execution optimization 35 | opts.addCPU(true); 36 | // Create an ONNX session using the environment, model path, and options 37 | session = env.createSession(modelPath, opts); 38 | // Reset states 39 | resetStates(); 40 | } 41 | 42 | /** 43 | * Reset states 44 | */ 45 | void resetStates() { 46 | h = new float[2][1][64]; 47 | c = new float[2][1][64]; 48 | lastSr = 0; 49 | lastBatchSize = 0; 50 | } 51 | 52 | public void close() throws OrtException { 53 | session.close(); 54 | } 55 | 56 | /** 57 | * Define inner class ValidationResult 58 | */ 59 | public static class 
ValidationResult { 60 | public final float[][] x; 61 | public final int sr; 62 | 63 | // Constructor 64 | public ValidationResult(float[][] x, int sr) { 65 | this.x = x; 66 | this.sr = sr; 67 | } 68 | } 69 | 70 | /** 71 | * Function to validate input data 72 | */ 73 | private ValidationResult validateInput(float[][] x, int sr) { 74 | // Process the input data with dimension 1 75 | if (x.length == 1) { 76 | x = new float[][]{x[0]}; 77 | } 78 | // Throw an exception when the input data dimension is greater than 2 79 | if (x.length > 2) { 80 | throw new IllegalArgumentException("Incorrect audio data dimension: " + x[0].length); 81 | } 82 | 83 | // Process the input data when the sample rate is not equal to 16000 and is a multiple of 16000 84 | if (sr != 16000 && (sr % 16000 == 0)) { 85 | int step = sr / 16000; 86 | float[][] reducedX = new float[x.length][]; 87 | 88 | for (int i = 0; i < x.length; i++) { 89 | float[] current = x[i]; 90 | float[] newArr = new float[(current.length + step - 1) / step]; 91 | 92 | for (int j = 0, index = 0; j < current.length; j += step, index++) { 93 | newArr[index] = current[j]; 94 | } 95 | 96 | reducedX[i] = newArr; 97 | } 98 | 99 | x = reducedX; 100 | sr = 16000; 101 | } 102 | 103 | // If the sample rate is not in the list of supported sample rates, throw an exception 104 | if (!SAMPLE_RATES.contains(sr)) { 105 | throw new IllegalArgumentException("Only supports sample rates " + SAMPLE_RATES + " (or multiples of 16000)"); 106 | } 107 | 108 | // If the input audio block is too short, throw an exception 109 | if (((float) sr) / x[0].length > 31.25) { 110 | throw new IllegalArgumentException("Input audio is too short"); 111 | } 112 | 113 | // Return the validated result 114 | return new ValidationResult(x, sr); 115 | } 116 | 117 | /** 118 | * Method to call the ONNX model 119 | */ 120 | public float[] call(float[][] x, int sr) throws OrtException { 121 | ValidationResult result = validateInput(x, sr); 122 | x = result.x; 123 | sr = result.sr; 124 | 125 | int batchSize = x.length; 126 | 127 | if (lastBatchSize == 0 || lastSr != sr || lastBatchSize != batchSize) { 128 | resetStates(); 129 | } 130 | 131 | OrtEnvironment env = OrtEnvironment.getEnvironment(); 132 | 133 | OnnxTensor inputTensor = null; 134 | OnnxTensor hTensor = null; 135 | OnnxTensor cTensor = null; 136 | OnnxTensor srTensor = null; 137 | OrtSession.Result ortOutputs = null; 138 | 139 | try { 140 | // Create input tensors 141 | inputTensor = OnnxTensor.createTensor(env, x); 142 | hTensor = OnnxTensor.createTensor(env, h); 143 | cTensor = OnnxTensor.createTensor(env, c); 144 | srTensor = OnnxTensor.createTensor(env, new long[]{sr}); 145 | 146 | Map inputs = new HashMap<>(); 147 | inputs.put("input", inputTensor); 148 | inputs.put("sr", srTensor); 149 | inputs.put("h", hTensor); 150 | inputs.put("c", cTensor); 151 | 152 | // Call the ONNX model for calculation 153 | ortOutputs = session.run(inputs); 154 | // Get the output results 155 | float[][] output = (float[][]) ortOutputs.get(0).getValue(); 156 | h = (float[][][]) ortOutputs.get(1).getValue(); 157 | c = (float[][][]) ortOutputs.get(2).getValue(); 158 | 159 | lastSr = sr; 160 | lastBatchSize = batchSize; 161 | return output[0]; 162 | } finally { 163 | if (inputTensor != null) { 164 | inputTensor.close(); 165 | } 166 | if (hTensor != null) { 167 | hTensor.close(); 168 | } 169 | if (cTensor != null) { 170 | cTensor.close(); 171 | } 172 | if (srTensor != null) { 173 | srTensor.close(); 174 | } 175 | if (ortOutputs != null) { 176 | 
ortOutputs.close(); 177 | } 178 | } 179 | } 180 | } 181 | -------------------------------------------------------------------------------- /examples/java-wav-file-example/src/main/java/org/example/App.java: -------------------------------------------------------------------------------- 1 | package org.example; 2 | 3 | import ai.onnxruntime.OrtException; 4 | import java.io.File; 5 | import java.util.List; 6 | 7 | public class App { 8 | 9 | private static final String MODEL_PATH = "/path/silero_vad.onnx"; 10 | private static final String EXAMPLE_WAV_FILE = "/path/example.wav"; 11 | private static final int SAMPLE_RATE = 16000; 12 | private static final float THRESHOLD = 0.5f; 13 | private static final int MIN_SPEECH_DURATION_MS = 250; 14 | private static final float MAX_SPEECH_DURATION_SECONDS = Float.POSITIVE_INFINITY; 15 | private static final int MIN_SILENCE_DURATION_MS = 100; 16 | private static final int SPEECH_PAD_MS = 30; 17 | 18 | public static void main(String[] args) { 19 | // Initialize the Voice Activity Detector 20 | SileroVadDetector vadDetector; 21 | try { 22 | vadDetector = new SileroVadDetector(MODEL_PATH, THRESHOLD, SAMPLE_RATE, 23 | MIN_SPEECH_DURATION_MS, MAX_SPEECH_DURATION_SECONDS, MIN_SILENCE_DURATION_MS, SPEECH_PAD_MS); 24 | fromWavFile(vadDetector, new File(EXAMPLE_WAV_FILE)); 25 | } catch (OrtException e) { 26 | System.err.println("Error initializing the VAD detector: " + e.getMessage()); 27 | } 28 | } 29 | 30 | public static void fromWavFile(SileroVadDetector vadDetector, File wavFile) { 31 | List speechTimeList = vadDetector.getSpeechSegmentList(wavFile); 32 | for (SileroSpeechSegment speechSegment : speechTimeList) { 33 | System.out.println(String.format("start second: %f, end second: %f", 34 | speechSegment.getStartSecond(), speechSegment.getEndSecond())); 35 | } 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /examples/java-wav-file-example/src/main/java/org/example/SileroSpeechSegment.java: -------------------------------------------------------------------------------- 1 | package org.example; 2 | 3 | 4 | public class SileroSpeechSegment { 5 | private Integer startOffset; 6 | private Integer endOffset; 7 | private Float startSecond; 8 | private Float endSecond; 9 | 10 | public SileroSpeechSegment() { 11 | } 12 | 13 | public SileroSpeechSegment(Integer startOffset, Integer endOffset, Float startSecond, Float endSecond) { 14 | this.startOffset = startOffset; 15 | this.endOffset = endOffset; 16 | this.startSecond = startSecond; 17 | this.endSecond = endSecond; 18 | } 19 | 20 | public Integer getStartOffset() { 21 | return startOffset; 22 | } 23 | 24 | public Integer getEndOffset() { 25 | return endOffset; 26 | } 27 | 28 | public Float getStartSecond() { 29 | return startSecond; 30 | } 31 | 32 | public Float getEndSecond() { 33 | return endSecond; 34 | } 35 | 36 | public void setStartOffset(Integer startOffset) { 37 | this.startOffset = startOffset; 38 | } 39 | 40 | public void setEndOffset(Integer endOffset) { 41 | this.endOffset = endOffset; 42 | } 43 | 44 | public void setStartSecond(Float startSecond) { 45 | this.startSecond = startSecond; 46 | } 47 | 48 | public void setEndSecond(Float endSecond) { 49 | this.endSecond = endSecond; 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /examples/java-wav-file-example/src/main/java/org/example/SileroVadDetector.java: -------------------------------------------------------------------------------- 1 | package 
org.example; 2 | 3 | 4 | import ai.onnxruntime.OrtException; 5 | 6 | import javax.sound.sampled.AudioInputStream; 7 | import javax.sound.sampled.AudioSystem; 8 | import java.io.File; 9 | import java.util.ArrayList; 10 | import java.util.Comparator; 11 | import java.util.List; 12 | 13 | public class SileroVadDetector { 14 | private final SileroVadOnnxModel model; 15 | private final float threshold; 16 | private final float negThreshold; 17 | private final int samplingRate; 18 | private final int windowSizeSample; 19 | private final float minSpeechSamples; 20 | private final float speechPadSamples; 21 | private final float maxSpeechSamples; 22 | private final float minSilenceSamples; 23 | private final float minSilenceSamplesAtMaxSpeech; 24 | private int audioLengthSamples; 25 | private static final float THRESHOLD_GAP = 0.15f; 26 | private static final Integer SAMPLING_RATE_8K = 8000; 27 | private static final Integer SAMPLING_RATE_16K = 16000; 28 | 29 | /** 30 | * Constructor 31 | * @param onnxModelPath the path of silero-vad onnx model 32 | * @param threshold threshold for speech start 33 | * @param samplingRate audio sampling rate, only available for [8k, 16k] 34 | * @param minSpeechDurationMs Minimum speech length in millis, any speech duration that smaller than this value would not be considered as speech 35 | * @param maxSpeechDurationSeconds Maximum speech length in millis, recommend to be set as Float.POSITIVE_INFINITY 36 | * @param minSilenceDurationMs Minimum silence length in millis, any silence duration that smaller than this value would not be considered as silence 37 | * @param speechPadMs Additional pad millis for speech start and end 38 | * @throws OrtException 39 | */ 40 | public SileroVadDetector(String onnxModelPath, float threshold, int samplingRate, 41 | int minSpeechDurationMs, float maxSpeechDurationSeconds, 42 | int minSilenceDurationMs, int speechPadMs) throws OrtException { 43 | if (samplingRate != SAMPLING_RATE_8K && samplingRate != SAMPLING_RATE_16K) { 44 | throw new IllegalArgumentException("Sampling rate not support, only available for [8000, 16000]"); 45 | } 46 | this.model = new SileroVadOnnxModel(onnxModelPath); 47 | this.samplingRate = samplingRate; 48 | this.threshold = threshold; 49 | this.negThreshold = threshold - THRESHOLD_GAP; 50 | if (samplingRate == SAMPLING_RATE_16K) { 51 | this.windowSizeSample = 512; 52 | } else { 53 | this.windowSizeSample = 256; 54 | } 55 | this.minSpeechSamples = samplingRate * minSpeechDurationMs / 1000f; 56 | this.speechPadSamples = samplingRate * speechPadMs / 1000f; 57 | this.maxSpeechSamples = samplingRate * maxSpeechDurationSeconds - windowSizeSample - 2 * speechPadSamples; 58 | this.minSilenceSamples = samplingRate * minSilenceDurationMs / 1000f; 59 | this.minSilenceSamplesAtMaxSpeech = samplingRate * 98 / 1000f; 60 | this.reset(); 61 | } 62 | 63 | /** 64 | * Method to reset the state 65 | */ 66 | public void reset() { 67 | model.resetStates(); 68 | } 69 | 70 | /** 71 | * Get speech segment list by given wav-format file 72 | * @param wavFile wav file 73 | * @return list of speech segment 74 | */ 75 | public List getSpeechSegmentList(File wavFile) { 76 | reset(); 77 | try (AudioInputStream audioInputStream = AudioSystem.getAudioInputStream(wavFile)){ 78 | List speechProbList = new ArrayList<>(); 79 | this.audioLengthSamples = audioInputStream.available() / 2; 80 | byte[] data = new byte[this.windowSizeSample * 2]; 81 | int numBytesRead = 0; 82 | 83 | while ((numBytesRead = audioInputStream.read(data)) != -1) { 84 | if 
(numBytesRead <= 0) { 85 | break; 86 | } 87 | // Convert the byte array to a float array 88 | float[] audioData = new float[data.length / 2]; 89 | for (int i = 0; i < audioData.length; i++) { 90 | audioData[i] = ((data[i * 2] & 0xff) | (data[i * 2 + 1] << 8)) / 32767.0f; 91 | } 92 | 93 | float speechProb = 0; 94 | try { 95 | speechProb = model.call(new float[][]{audioData}, samplingRate)[0]; 96 | speechProbList.add(speechProb); 97 | } catch (OrtException e) { 98 | throw e; 99 | } 100 | } 101 | return calculateProb(speechProbList); 102 | } catch (Exception e) { 103 | throw new RuntimeException("SileroVadDetector getSpeechTimeList with error", e); 104 | } 105 | } 106 | 107 | /** 108 | * Calculate speech segement by probability 109 | * @param speechProbList speech probability list 110 | * @return list of speech segment 111 | */ 112 | private List calculateProb(List speechProbList) { 113 | List result = new ArrayList<>(); 114 | boolean triggered = false; 115 | int tempEnd = 0, prevEnd = 0, nextStart = 0; 116 | SileroSpeechSegment segment = new SileroSpeechSegment(); 117 | 118 | for (int i = 0; i < speechProbList.size(); i++) { 119 | Float speechProb = speechProbList.get(i); 120 | if (speechProb >= threshold && (tempEnd != 0)) { 121 | tempEnd = 0; 122 | if (nextStart < prevEnd) { 123 | nextStart = windowSizeSample * i; 124 | } 125 | } 126 | 127 | if (speechProb >= threshold && !triggered) { 128 | triggered = true; 129 | segment.setStartOffset(windowSizeSample * i); 130 | continue; 131 | } 132 | 133 | if (triggered && (windowSizeSample * i) - segment.getStartOffset() > maxSpeechSamples) { 134 | if (prevEnd != 0) { 135 | segment.setEndOffset(prevEnd); 136 | result.add(segment); 137 | segment = new SileroSpeechSegment(); 138 | if (nextStart < prevEnd) { 139 | triggered = false; 140 | }else { 141 | segment.setStartOffset(nextStart); 142 | } 143 | prevEnd = 0; 144 | nextStart = 0; 145 | tempEnd = 0; 146 | }else { 147 | segment.setEndOffset(windowSizeSample * i); 148 | result.add(segment); 149 | segment = new SileroSpeechSegment(); 150 | prevEnd = 0; 151 | nextStart = 0; 152 | tempEnd = 0; 153 | triggered = false; 154 | continue; 155 | } 156 | } 157 | 158 | if (speechProb < negThreshold && triggered) { 159 | if (tempEnd == 0) { 160 | tempEnd = windowSizeSample * i; 161 | } 162 | if (((windowSizeSample * i) - tempEnd) > minSilenceSamplesAtMaxSpeech) { 163 | prevEnd = tempEnd; 164 | } 165 | if ((windowSizeSample * i) - tempEnd < minSilenceSamples) { 166 | continue; 167 | }else { 168 | segment.setEndOffset(tempEnd); 169 | if ((segment.getEndOffset() - segment.getStartOffset()) > minSpeechSamples) { 170 | result.add(segment); 171 | } 172 | segment = new SileroSpeechSegment(); 173 | prevEnd = 0; 174 | nextStart = 0; 175 | tempEnd = 0; 176 | triggered = false; 177 | continue; 178 | } 179 | } 180 | } 181 | 182 | if (segment.getStartOffset() != null && (audioLengthSamples - segment.getStartOffset()) > minSpeechSamples) { 183 | segment.setEndOffset(audioLengthSamples); 184 | result.add(segment); 185 | } 186 | 187 | for (int i = 0; i < result.size(); i++) { 188 | SileroSpeechSegment item = result.get(i); 189 | if (i == 0) { 190 | item.setStartOffset((int)(Math.max(0,item.getStartOffset() - speechPadSamples))); 191 | } 192 | if (i != result.size() - 1) { 193 | SileroSpeechSegment nextItem = result.get(i + 1); 194 | Integer silenceDuration = nextItem.getStartOffset() - item.getEndOffset(); 195 | if(silenceDuration < 2 * speechPadSamples){ 196 | item.setEndOffset(item.getEndOffset() + (silenceDuration / 2 )); 197 
| nextItem.setStartOffset(Math.max(0, nextItem.getStartOffset() - (silenceDuration / 2))); 198 | } else { 199 | item.setEndOffset((int)(Math.min(audioLengthSamples, item.getEndOffset() + speechPadSamples))); 200 | nextItem.setStartOffset((int)(Math.max(0,nextItem.getStartOffset() - speechPadSamples))); 201 | } 202 | }else { 203 | item.setEndOffset((int)(Math.min(audioLengthSamples, item.getEndOffset() + speechPadSamples))); 204 | } 205 | } 206 | 207 | return mergeListAndCalculateSecond(result, samplingRate); 208 | } 209 | 210 | private List mergeListAndCalculateSecond(List original, Integer samplingRate) { 211 | List result = new ArrayList<>(); 212 | if (original == null || original.size() == 0) { 213 | return result; 214 | } 215 | Integer left = original.get(0).getStartOffset(); 216 | Integer right = original.get(0).getEndOffset(); 217 | if (original.size() > 1) { 218 | original.sort(Comparator.comparingLong(SileroSpeechSegment::getStartOffset)); 219 | for (int i = 1; i < original.size(); i++) { 220 | SileroSpeechSegment segment = original.get(i); 221 | 222 | if (segment.getStartOffset() > right) { 223 | result.add(new SileroSpeechSegment(left, right, 224 | calculateSecondByOffset(left, samplingRate), calculateSecondByOffset(right, samplingRate))); 225 | left = segment.getStartOffset(); 226 | right = segment.getEndOffset(); 227 | } else { 228 | right = Math.max(right, segment.getEndOffset()); 229 | } 230 | } 231 | result.add(new SileroSpeechSegment(left, right, 232 | calculateSecondByOffset(left, samplingRate), calculateSecondByOffset(right, samplingRate))); 233 | }else { 234 | result.add(new SileroSpeechSegment(left, right, 235 | calculateSecondByOffset(left, samplingRate), calculateSecondByOffset(right, samplingRate))); 236 | } 237 | return result; 238 | } 239 | 240 | private Float calculateSecondByOffset(Integer offset, Integer samplingRate) { 241 | float secondValue = offset * 1.0f / samplingRate; 242 | return (float) Math.floor(secondValue * 1000.0f) / 1000.0f; 243 | } 244 | } 245 | -------------------------------------------------------------------------------- /examples/java-wav-file-example/src/main/java/org/example/SileroVadOnnxModel.java: -------------------------------------------------------------------------------- 1 | package org.example; 2 | 3 | import ai.onnxruntime.OnnxTensor; 4 | import ai.onnxruntime.OrtEnvironment; 5 | import ai.onnxruntime.OrtException; 6 | import ai.onnxruntime.OrtSession; 7 | import java.util.Arrays; 8 | import java.util.HashMap; 9 | import java.util.List; 10 | import java.util.Map; 11 | 12 | public class SileroVadOnnxModel { 13 | // Define private variable OrtSession 14 | private final OrtSession session; 15 | private float[][][] state; 16 | private float[][] context; 17 | // Define the last sample rate 18 | private int lastSr = 0; 19 | // Define the last batch size 20 | private int lastBatchSize = 0; 21 | // Define a list of supported sample rates 22 | private static final List SAMPLE_RATES = Arrays.asList(8000, 16000); 23 | 24 | // Constructor 25 | public SileroVadOnnxModel(String modelPath) throws OrtException { 26 | // Get the ONNX runtime environment 27 | OrtEnvironment env = OrtEnvironment.getEnvironment(); 28 | // Create an ONNX session options object 29 | OrtSession.SessionOptions opts = new OrtSession.SessionOptions(); 30 | // Set the InterOp thread count to 1, InterOp threads are used for parallel processing of different computation graph operations 31 | opts.setInterOpNumThreads(1); 32 | // Set the IntraOp thread count to 1, IntraOp 
threads are used for parallel processing within a single operation 33 | opts.setIntraOpNumThreads(1); 34 | // Add a CPU device, setting to false disables CPU execution optimization 35 | opts.addCPU(true); 36 | // Create an ONNX session using the environment, model path, and options 37 | session = env.createSession(modelPath, opts); 38 | // Reset states 39 | resetStates(); 40 | } 41 | 42 | /** 43 | * Reset states 44 | */ 45 | void resetStates() { 46 | state = new float[2][1][128]; 47 | context = new float[0][]; 48 | lastSr = 0; 49 | lastBatchSize = 0; 50 | } 51 | 52 | public void close() throws OrtException { 53 | session.close(); 54 | } 55 | 56 | /** 57 | * Define inner class ValidationResult 58 | */ 59 | public static class ValidationResult { 60 | public final float[][] x; 61 | public final int sr; 62 | 63 | // Constructor 64 | public ValidationResult(float[][] x, int sr) { 65 | this.x = x; 66 | this.sr = sr; 67 | } 68 | } 69 | 70 | /** 71 | * Function to validate input data 72 | */ 73 | private ValidationResult validateInput(float[][] x, int sr) { 74 | // Process the input data with dimension 1 75 | if (x.length == 1) { 76 | x = new float[][]{x[0]}; 77 | } 78 | // Throw an exception when the input data dimension is greater than 2 79 | if (x.length > 2) { 80 | throw new IllegalArgumentException("Incorrect audio data dimension: " + x[0].length); 81 | } 82 | 83 | // Process the input data when the sample rate is not equal to 16000 and is a multiple of 16000 84 | if (sr != 16000 && (sr % 16000 == 0)) { 85 | int step = sr / 16000; 86 | float[][] reducedX = new float[x.length][]; 87 | 88 | for (int i = 0; i < x.length; i++) { 89 | float[] current = x[i]; 90 | float[] newArr = new float[(current.length + step - 1) / step]; 91 | 92 | for (int j = 0, index = 0; j < current.length; j += step, index++) { 93 | newArr[index] = current[j]; 94 | } 95 | 96 | reducedX[i] = newArr; 97 | } 98 | 99 | x = reducedX; 100 | sr = 16000; 101 | } 102 | 103 | // If the sample rate is not in the list of supported sample rates, throw an exception 104 | if (!SAMPLE_RATES.contains(sr)) { 105 | throw new IllegalArgumentException("Only supports sample rates " + SAMPLE_RATES + " (or multiples of 16000)"); 106 | } 107 | 108 | // If the input audio block is too short, throw an exception 109 | if (((float) sr) / x[0].length > 31.25) { 110 | throw new IllegalArgumentException("Input audio is too short"); 111 | } 112 | 113 | // Return the validated result 114 | return new ValidationResult(x, sr); 115 | } 116 | 117 | private static float[][] concatenate(float[][] a, float[][] b) { 118 | if (a.length != b.length) { 119 | throw new IllegalArgumentException("The number of rows in both arrays must be the same."); 120 | } 121 | 122 | int rows = a.length; 123 | int colsA = a[0].length; 124 | int colsB = b[0].length; 125 | float[][] result = new float[rows][colsA + colsB]; 126 | 127 | for (int i = 0; i < rows; i++) { 128 | System.arraycopy(a[i], 0, result[i], 0, colsA); 129 | System.arraycopy(b[i], 0, result[i], colsA, colsB); 130 | } 131 | 132 | return result; 133 | } 134 | 135 | private static float[][] getLastColumns(float[][] array, int contextSize) { 136 | int rows = array.length; 137 | int cols = array[0].length; 138 | 139 | if (contextSize > cols) { 140 | throw new IllegalArgumentException("contextSize cannot be greater than the number of columns in the array."); 141 | } 142 | 143 | float[][] result = new float[rows][contextSize]; 144 | 145 | for (int i = 0; i < rows; i++) { 146 | System.arraycopy(array[i], cols - contextSize, 
result[i], 0, contextSize); 147 | } 148 | 149 | return result; 150 | } 151 | 152 | /** 153 | * Method to call the ONNX model 154 | */ 155 | public float[] call(float[][] x, int sr) throws OrtException { 156 | ValidationResult result = validateInput(x, sr); 157 | x = result.x; 158 | sr = result.sr; 159 | int numberSamples = 256; 160 | if (sr == 16000) { 161 | numberSamples = 512; 162 | } 163 | 164 | if (x[0].length != numberSamples) { 165 | throw new IllegalArgumentException("Provided number of samples is " + x[0].length + " (Supported values: 256 for 8000 sample rate, 512 for 16000)"); 166 | } 167 | 168 | int batchSize = x.length; 169 | 170 | int contextSize = 32; 171 | if (sr == 16000) { 172 | contextSize = 64; 173 | } 174 | 175 | if (lastBatchSize == 0) { 176 | resetStates(); 177 | } 178 | if (lastSr != 0 && lastSr != sr) { 179 | resetStates(); 180 | } 181 | if (lastBatchSize != 0 && lastBatchSize != batchSize) { 182 | resetStates(); 183 | } 184 | 185 | if (context.length == 0) { 186 | context = new float[batchSize][contextSize]; 187 | } 188 | 189 | x = concatenate(context, x); 190 | 191 | OrtEnvironment env = OrtEnvironment.getEnvironment(); 192 | 193 | OnnxTensor inputTensor = null; 194 | OnnxTensor stateTensor = null; 195 | OnnxTensor srTensor = null; 196 | OrtSession.Result ortOutputs = null; 197 | 198 | try { 199 | // Create input tensors 200 | inputTensor = OnnxTensor.createTensor(env, x); 201 | stateTensor = OnnxTensor.createTensor(env, state); 202 | srTensor = OnnxTensor.createTensor(env, new long[]{sr}); 203 | 204 | Map inputs = new HashMap<>(); 205 | inputs.put("input", inputTensor); 206 | inputs.put("sr", srTensor); 207 | inputs.put("state", stateTensor); 208 | 209 | // Call the ONNX model for calculation 210 | ortOutputs = session.run(inputs); 211 | // Get the output results 212 | float[][] output = (float[][]) ortOutputs.get(0).getValue(); 213 | state = (float[][][]) ortOutputs.get(1).getValue(); 214 | 215 | context = getLastColumns(x, contextSize); 216 | lastSr = sr; 217 | lastBatchSize = batchSize; 218 | return output[0]; 219 | } finally { 220 | if (inputTensor != null) { 221 | inputTensor.close(); 222 | } 223 | if (stateTensor != null) { 224 | stateTensor.close(); 225 | } 226 | if (srTensor != null) { 227 | srTensor.close(); 228 | } 229 | if (ortOutputs != null) { 230 | ortOutputs.close(); 231 | } 232 | } 233 | } 234 | } 235 | -------------------------------------------------------------------------------- /examples/microphone_and_webRTC_integration/README.md: -------------------------------------------------------------------------------- 1 | 2 | In this example, an integration with the microphone and the webRTC VAD has been done. I used [this](https://github.com/mozilla/DeepSpeech-examples/tree/r0.8/mic_vad_streaming) as a draft. 3 | Here a short video to present the results: 4 | 5 | https://user-images.githubusercontent.com/28188499/116685087-182ff100-a9b2-11eb-927d-ed9f621226ee.mp4 6 | 7 | # Requirements: 8 | The libraries used for the following example are: 9 | ``` 10 | Python == 3.6.9 11 | webrtcvad >= 2.0.10 12 | torchaudio >= 0.8.1 13 | torch >= 1.8.1 14 | halo >= 0.0.31 15 | Soundfile >= 0.13.3 16 | ``` 17 | Using pip3: 18 | ``` 19 | pip3 install webrtcvad 20 | pip3 install torchaudio 21 | pip3 install torch 22 | pip3 install halo 23 | pip3 install soundfile 24 | ``` 25 | Moreover, to make the code easier, the default sample_rate is 16KHz without resampling. 
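For reference, a typical invocation with the defaults (16 kHz capture, default input device) could look like the following; the flags correspond to the script's argparse options, and the device index used with `-d` is only an illustration, so adjust it to your own microphone:
```
python3 microphone_and_webRTC_integration.py -v 3 --nospinner
# explicitly pick an input device and re-download the latest silero VAD model:
python3 microphone_and_webRTC_integration.py -v 2 -d 1 --reload
```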
26 | 27 | This example has been tested on ``` ubuntu 18.04.3 LTS``` 28 | 29 | -------------------------------------------------------------------------------- /examples/microphone_and_webRTC_integration/microphone_and_webRTC_integration.py: -------------------------------------------------------------------------------- 1 | import collections, queue 2 | import numpy as np 3 | import pyaudio 4 | import webrtcvad 5 | from halo import Halo 6 | import torch 7 | import torchaudio 8 | 9 | class Audio(object): 10 | """Streams raw audio from microphone. Data is received in a separate thread, and stored in a buffer, to be read from.""" 11 | 12 | FORMAT = pyaudio.paInt16 13 | # Network/VAD rate-space 14 | RATE_PROCESS = 16000 15 | CHANNELS = 1 16 | BLOCKS_PER_SECOND = 50 17 | 18 | def __init__(self, callback=None, device=None, input_rate=RATE_PROCESS): 19 | def proxy_callback(in_data, frame_count, time_info, status): 20 | #pylint: disable=unused-argument 21 | callback(in_data) 22 | return (None, pyaudio.paContinue) 23 | if callback is None: callback = lambda in_data: self.buffer_queue.put(in_data) 24 | self.buffer_queue = queue.Queue() 25 | self.device = device 26 | self.input_rate = input_rate 27 | self.sample_rate = self.RATE_PROCESS 28 | self.block_size = int(self.RATE_PROCESS / float(self.BLOCKS_PER_SECOND)) 29 | self.block_size_input = int(self.input_rate / float(self.BLOCKS_PER_SECOND)) 30 | self.pa = pyaudio.PyAudio() 31 | 32 | kwargs = { 33 | 'format': self.FORMAT, 34 | 'channels': self.CHANNELS, 35 | 'rate': self.input_rate, 36 | 'input': True, 37 | 'frames_per_buffer': self.block_size_input, 38 | 'stream_callback': proxy_callback, 39 | } 40 | 41 | self.chunk = None 42 | # if not default device 43 | if self.device: 44 | kwargs['input_device_index'] = self.device 45 | 46 | self.stream = self.pa.open(**kwargs) 47 | self.stream.start_stream() 48 | 49 | def read(self): 50 | """Return a block of audio data, blocking if necessary.""" 51 | return self.buffer_queue.get() 52 | 53 | def destroy(self): 54 | self.stream.stop_stream() 55 | self.stream.close() 56 | self.pa.terminate() 57 | 58 | frame_duration_ms = property(lambda self: 1000 * self.block_size // self.sample_rate) 59 | 60 | 61 | class VADAudio(Audio): 62 | """Filter & segment audio with voice activity detection.""" 63 | 64 | def __init__(self, aggressiveness=3, device=None, input_rate=None): 65 | super().__init__(device=device, input_rate=input_rate) 66 | self.vad = webrtcvad.Vad(aggressiveness) 67 | 68 | def frame_generator(self): 69 | """Generator that yields all audio frames from microphone.""" 70 | if self.input_rate == self.RATE_PROCESS: 71 | while True: 72 | yield self.read() 73 | else: 74 | raise Exception("Resampling required") 75 | 76 | def vad_collector(self, padding_ms=300, ratio=0.75, frames=None): 77 | """Generator that yields series of consecutive audio frames comprising each utterence, separated by yielding a single None. 78 | Determines voice activity by ratio of frames in padding_ms. Uses a buffer to include padding_ms prior to being triggered. 79 | Example: (frame, ..., frame, None, frame, ..., frame, None, ...) 
80 | |---utterence---| |---utterence---| 81 | """ 82 | if frames is None: frames = self.frame_generator() 83 | num_padding_frames = padding_ms // self.frame_duration_ms 84 | ring_buffer = collections.deque(maxlen=num_padding_frames) 85 | triggered = False 86 | 87 | for frame in frames: 88 | if len(frame) < 640: 89 | return 90 | 91 | is_speech = self.vad.is_speech(frame, self.sample_rate) 92 | 93 | if not triggered: 94 | ring_buffer.append((frame, is_speech)) 95 | num_voiced = len([f for f, speech in ring_buffer if speech]) 96 | if num_voiced > ratio * ring_buffer.maxlen: 97 | triggered = True 98 | for f, s in ring_buffer: 99 | yield f 100 | ring_buffer.clear() 101 | 102 | else: 103 | yield frame 104 | ring_buffer.append((frame, is_speech)) 105 | num_unvoiced = len([f for f, speech in ring_buffer if not speech]) 106 | if num_unvoiced > ratio * ring_buffer.maxlen: 107 | triggered = False 108 | yield None 109 | ring_buffer.clear() 110 | 111 | def main(ARGS): 112 | # Start audio with VAD 113 | vad_audio = VADAudio(aggressiveness=ARGS.webRTC_aggressiveness, 114 | device=ARGS.device, 115 | input_rate=ARGS.rate) 116 | 117 | print("Listening (ctrl-C to exit)...") 118 | frames = vad_audio.vad_collector() 119 | 120 | # load silero VAD 121 | torchaudio.set_audio_backend("soundfile") 122 | model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', 123 | model=ARGS.silaro_model_name, 124 | force_reload= ARGS.reload) 125 | (get_speech_ts,_,_, _,_, _, _) = utils 126 | 127 | 128 | # Stream from microphone to DeepSpeech using VAD 129 | spinner = None 130 | if not ARGS.nospinner: 131 | spinner = Halo(spinner='line') 132 | wav_data = bytearray() 133 | for frame in frames: 134 | if frame is not None: 135 | if spinner: spinner.start() 136 | 137 | wav_data.extend(frame) 138 | else: 139 | if spinner: spinner.stop() 140 | print("webRTC has detected a possible speech") 141 | 142 | newsound= np.frombuffer(wav_data,np.int16) 143 | audio_float32=Int2Float(newsound) 144 | time_stamps =get_speech_ts(audio_float32, model,num_steps=ARGS.num_steps,trig_sum=ARGS.trig_sum,neg_trig_sum=ARGS.neg_trig_sum, 145 | num_samples_per_window=ARGS.num_samples_per_window,min_speech_samples=ARGS.min_speech_samples, 146 | min_silence_samples=ARGS.min_silence_samples) 147 | 148 | if(len(time_stamps)>0): 149 | print("silero VAD has detected a possible speech") 150 | else: 151 | print("silero VAD has detected a noise") 152 | print() 153 | wav_data = bytearray() 154 | 155 | 156 | def Int2Float(sound): 157 | _sound = np.copy(sound) # 158 | abs_max = np.abs(_sound).max() 159 | _sound = _sound.astype('float32') 160 | if abs_max > 0: 161 | _sound *= 1/abs_max 162 | audio_float32 = torch.from_numpy(_sound.squeeze()) 163 | return audio_float32 164 | 165 | if __name__ == '__main__': 166 | DEFAULT_SAMPLE_RATE = 16000 167 | 168 | import argparse 169 | parser = argparse.ArgumentParser(description="Stream from microphone to webRTC and silero VAD") 170 | 171 | parser.add_argument('-v', '--webRTC_aggressiveness', type=int, default=3, 172 | help="Set aggressiveness of webRTC: an integer between 0 and 3, 0 being the least aggressive about filtering out non-speech, 3 the most aggressive. Default: 3") 173 | parser.add_argument('--nospinner', action='store_true', 174 | help="Disable spinner") 175 | parser.add_argument('-d', '--device', type=int, default=None, 176 | help="Device input index (Int) as listed by pyaudio.PyAudio.get_device_info_by_index(). 
If not provided, falls back to PyAudio.get_default_device().") 177 | 178 | parser.add_argument('-name', '--silaro_model_name', type=str, default="silero_vad", 179 | help="select the name of the model. You can select between 'silero_vad',''silero_vad_micro','silero_vad_micro_8k','silero_vad_mini','silero_vad_mini_8k'") 180 | parser.add_argument('--reload', action='store_true',help="download the last version of the silero vad") 181 | 182 | parser.add_argument('-ts', '--trig_sum', type=float, default=0.25, 183 | help="overlapping windows are used for each audio chunk, trig sum defines average probability among those windows for switching into triggered state (speech state)") 184 | 185 | parser.add_argument('-nts', '--neg_trig_sum', type=float, default=0.07, 186 | help="same as trig_sum, but for switching from triggered to non-triggered state (non-speech)") 187 | 188 | parser.add_argument('-N', '--num_steps', type=int, default=8, 189 | help="number of overlapping windows to split audio chunk into (we recommend 4 or 8)") 190 | 191 | parser.add_argument('-nspw', '--num_samples_per_window', type=int, default=4000, 192 | help="number of samples in each window, our models were trained using 4000 samples (250 ms) per window, so this is preferable value (lesser values reduce quality)") 193 | 194 | parser.add_argument('-msps', '--min_speech_samples', type=int, default=10000, 195 | help="minimum speech chunk duration in samples") 196 | 197 | parser.add_argument('-msis', '--min_silence_samples', type=int, default=500, 198 | help=" minimum silence duration in samples between to separate speech chunks") 199 | ARGS = parser.parse_args() 200 | ARGS.rate=DEFAULT_SAMPLE_RATE 201 | main(ARGS) 202 | -------------------------------------------------------------------------------- /examples/parallel_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Install Dependencies" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "# !pip install -q torchaudio\n", 17 | "SAMPLING_RATE = 16000\n", 18 | "import torch\n", 19 | "from pprint import pprint\n", 20 | "import time\n", 21 | "import shutil\n", 22 | "\n", 23 | "torch.set_num_threads(1)\n", 24 | "NUM_PROCESS=4 # set to the number of CPU cores in the machine\n", 25 | "NUM_COPIES=8\n", 26 | "# download wav files, make multiple copies\n", 27 | "torch.hub.download_url_to_file('https://models.silero.ai/vad_models/en.wav', f\"en_example0.wav\")\n", 28 | "for idx in range(NUM_COPIES-1):\n", 29 | " shutil.copy(f\"en_example0.wav\", f\"en_example{idx+1}.wav\")" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "metadata": {}, 35 | "source": [ 36 | "## Load VAD model from torch hub" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n", 46 | " model='silero_vad',\n", 47 | " force_reload=True,\n", 48 | " onnx=False)\n", 49 | "\n", 50 | "(get_speech_timestamps,\n", 51 | "save_audio,\n", 52 | "read_audio,\n", 53 | "VADIterator,\n", 54 | "collect_chunks) = utils" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": {}, 60 | "source": [ 61 | "## Define a vad process function" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": {}, 68 | 
"outputs": [], 69 | "source": [ 70 | "import multiprocessing\n", 71 | "\n", 72 | "vad_models = dict()\n", 73 | "\n", 74 | "def init_model(model):\n", 75 | " pid = multiprocessing.current_process().pid\n", 76 | " model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n", 77 | " model='silero_vad',\n", 78 | " force_reload=False,\n", 79 | " onnx=False)\n", 80 | " vad_models[pid] = model\n", 81 | "\n", 82 | "def vad_process(audio_file: str):\n", 83 | " \n", 84 | " pid = multiprocessing.current_process().pid\n", 85 | " \n", 86 | " with torch.no_grad():\n", 87 | " wav = read_audio(audio_file, sampling_rate=SAMPLING_RATE)\n", 88 | " return get_speech_timestamps(\n", 89 | " wav,\n", 90 | " vad_models[pid],\n", 91 | " 0.46, # speech prob threshold\n", 92 | " 16000, # sample rate\n", 93 | " 300, # min speech duration in ms\n", 94 | " 20, # max speech duration in seconds\n", 95 | " 600, # min silence duration\n", 96 | " 512, # window size\n", 97 | " 200, # spech pad ms\n", 98 | " )" 99 | ] 100 | }, 101 | { 102 | "cell_type": "markdown", 103 | "metadata": {}, 104 | "source": [ 105 | "## Parallelization" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "from concurrent.futures import ProcessPoolExecutor, as_completed\n", 115 | "\n", 116 | "futures = []\n", 117 | "\n", 118 | "with ProcessPoolExecutor(max_workers=NUM_PROCESS, initializer=init_model, initargs=(model,)) as ex:\n", 119 | " for i in range(NUM_COPIES):\n", 120 | " futures.append(ex.submit(vad_process, f\"en_example{idx}.wav\"))\n", 121 | "\n", 122 | "for finished in as_completed(futures):\n", 123 | " pprint(finished.result())" 124 | ] 125 | } 126 | ], 127 | "metadata": { 128 | "kernelspec": { 129 | "display_name": "Python 3 (ipykernel)", 130 | "language": "python", 131 | "name": "python3" 132 | }, 133 | "language_info": { 134 | "codemirror_mode": { 135 | "name": "ipython", 136 | "version": 3 137 | }, 138 | "file_extension": ".py", 139 | "mimetype": "text/x-python", 140 | "name": "python", 141 | "nbconvert_exporter": "python", 142 | "pygments_lexer": "ipython3", 143 | "version": "3.10.14" 144 | }, 145 | "toc": { 146 | "base_numbering": 1, 147 | "nav_menu": {}, 148 | "number_sections": true, 149 | "sideBar": true, 150 | "skip_h1_title": false, 151 | "title_cell": "Table of Contents", 152 | "title_sidebar": "Contents", 153 | "toc_cell": false, 154 | "toc_position": {}, 155 | "toc_section_display": true, 156 | "toc_window_display": false 157 | } 158 | }, 159 | "nbformat": 4, 160 | "nbformat_minor": 2 161 | } 162 | -------------------------------------------------------------------------------- /examples/pyaudio-streaming/README.md: -------------------------------------------------------------------------------- 1 | # Pyaudio Streaming Example 2 | 3 | This example notebook shows how micophone audio fetched by pyaudio can be processed with Silero-VAD. 4 | 5 | It has been designed as a low-level example for binary real-time streaming using only the prediction of the model, processing the binary data and plotting the speech probabilities at the end to visualize it. 6 | 7 | Currently, the notebook consits of two examples: 8 | - One that records audio of a predefined length from the microphone, process it with Silero-VAD, and plots it afterwards. 9 | - The other one plots the speech probabilities in real-time (using jupyterplot) and records the audio until you press enter. 10 | 11 | This example does not work in google colab! For local usage only. 
12 | 13 | ## Example Video for the Real-Time Visualization 14 | 15 | 16 | https://user-images.githubusercontent.com/8079748/117580455-4622dd00-b0f8-11eb-858d-e6368ed4eada.mp4 17 | 18 | 19 | 20 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /examples/pyaudio-streaming/pyaudio-streaming-examples.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "76aa55ba", 6 | "metadata": {}, 7 | "source": [ 8 | "# Pyaudio Microphone Streaming Examples\n", 9 | "\n", 10 | "A simple notebook that uses pyaudio to get the microphone audio and feeds this audio then to Silero VAD.\n", 11 | "\n", 12 | "I created it as an example on how binary data from a stream could be feed into Silero VAD.\n", 13 | "\n", 14 | "\n", 15 | "Has been tested on Ubuntu 21.04 (x86). After you installed the dependencies below, no additional setup is required.\n", 16 | "\n", 17 | "This notebook does not work in google colab! For local usage only." 18 | ] 19 | }, 20 | { 21 | "cell_type": "markdown", 22 | "id": "4a4e15c2", 23 | "metadata": {}, 24 | "source": [ 25 | "## Dependencies\n", 26 | "The cell below lists all used dependencies and the used versions. Uncomment to install them from within the notebook." 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 1, 32 | "id": "24205cce", 33 | "metadata": { 34 | "ExecuteTime": { 35 | "end_time": "2024-10-09T08:47:34.056898Z", 36 | "start_time": "2024-10-09T08:47:34.053418Z" 37 | } 38 | }, 39 | "outputs": [], 40 | "source": [ 41 | "#!pip install numpy>=1.24.0\n", 42 | "#!pip install torch>=1.12.0\n", 43 | "#!pip install matplotlib>=3.6.0\n", 44 | "#!pip install torchaudio>=0.12.0\n", 45 | "#!pip install soundfile==0.12.1\n", 46 | "#!apt install python3-pyaudio (linux) or pip install pyaudio (windows)" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "id": "cd22818f", 52 | "metadata": {}, 53 | "source": [ 54 | "## Imports" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 2, 60 | "id": "994d7f3a", 61 | "metadata": { 62 | "ExecuteTime": { 63 | "end_time": "2024-10-09T08:47:39.005032Z", 64 | "start_time": "2024-10-09T08:47:36.489952Z" 65 | } 66 | }, 67 | "outputs": [ 68 | { 69 | "ename": "ModuleNotFoundError", 70 | "evalue": "No module named 'pyaudio'", 71 | "output_type": "error", 72 | "traceback": [ 73 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 74 | "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", 75 | "Cell \u001b[0;32mIn[2], line 8\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpylab\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mplt\u001b[39;00m\n\u001b[0;32m----> 8\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mpyaudio\u001b[39;00m\n", 76 | "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'pyaudio'" 77 | ] 78 | } 79 | ], 80 | "source": [ 81 | "import io\n", 82 | "import numpy as np\n", 83 | "import torch\n", 84 | "torch.set_num_threads(1)\n", 85 | "import torchaudio\n", 86 | "import matplotlib\n", 87 | "import matplotlib.pylab as plt\n", 88 | "import pyaudio" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | 
"id": "ac5c52f7", 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n", 99 | " model='silero_vad',\n", 100 | " force_reload=True)" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "id": "ad5919dc", 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "(get_speech_timestamps,\n", 111 | " save_audio,\n", 112 | " read_audio,\n", 113 | " VADIterator,\n", 114 | " collect_chunks) = utils" 115 | ] 116 | }, 117 | { 118 | "cell_type": "markdown", 119 | "id": "784d1ab6", 120 | "metadata": {}, 121 | "source": [ 122 | "### Helper Methods" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": null, 128 | "id": "af4bca64", 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "# Taken from utils_vad.py\n", 133 | "def validate(model,\n", 134 | " inputs: torch.Tensor):\n", 135 | " with torch.no_grad():\n", 136 | " outs = model(inputs)\n", 137 | " return outs\n", 138 | "\n", 139 | "# Provided by Alexander Veysov\n", 140 | "def int2float(sound):\n", 141 | " abs_max = np.abs(sound).max()\n", 142 | " sound = sound.astype('float32')\n", 143 | " if abs_max > 0:\n", 144 | " sound *= 1/32768\n", 145 | " sound = sound.squeeze() # depends on the use case\n", 146 | " return sound" 147 | ] 148 | }, 149 | { 150 | "cell_type": "markdown", 151 | "id": "ca13e514", 152 | "metadata": {}, 153 | "source": [ 154 | "## Pyaudio Set-up" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "id": "75f99022", 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "FORMAT = pyaudio.paInt16\n", 165 | "CHANNELS = 1\n", 166 | "SAMPLE_RATE = 16000\n", 167 | "CHUNK = int(SAMPLE_RATE / 10)\n", 168 | "\n", 169 | "audio = pyaudio.PyAudio()" 170 | ] 171 | }, 172 | { 173 | "cell_type": "markdown", 174 | "id": "4da7d2ef", 175 | "metadata": {}, 176 | "source": [ 177 | "## Simple Example\n", 178 | "The following example reads the audio as 250ms chunks from the microphone, converts them to a Pytorch Tensor, and gets the probabilities/confidences if the model thinks the frame is voiced." 
179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": null, 184 | "id": "6fe77661", 185 | "metadata": {}, 186 | "outputs": [], 187 | "source": [ 188 | "num_samples = 512" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": null, 194 | "id": "23f4da3e", 195 | "metadata": {}, 196 | "outputs": [], 197 | "source": [ 198 | "stream = audio.open(format=FORMAT,\n", 199 | "                    channels=CHANNELS,\n", 200 | "                    rate=SAMPLE_RATE,\n", 201 | "                    input=True,\n", 202 | "                    frames_per_buffer=CHUNK)\n", 203 | "data = []\n", 204 | "voiced_confidences = []\n", 205 | "\n", 206 | "frames_to_record = 50\n", 207 | "\n", 208 | "print(\"Started Recording\")\n", 209 | "for i in range(0, frames_to_record):\n", 210 | "\n", 211 | "    audio_chunk = stream.read(num_samples)\n", 212 | "\n", 213 | "    # in case you want to save the audio later\n", 214 | "    data.append(audio_chunk)\n", 215 | "\n", 216 | "    audio_int16 = np.frombuffer(audio_chunk, np.int16)\n", 217 | "\n", 218 | "    audio_float32 = int2float(audio_int16)\n", 219 | "\n", 220 | "    # get the confidences and add them to the list to plot them later\n", 221 | "    new_confidence = model(torch.from_numpy(audio_float32), 16000).item()\n", 222 | "    voiced_confidences.append(new_confidence)\n", 223 | "\n", 224 | "print(\"Stopped the recording\")\n", 225 | "\n", 226 | "# plot the confidences for the speech\n", 227 | "plt.figure(figsize=(20,6))\n", 228 | "plt.plot(voiced_confidences)\n", 229 | "plt.show()" 230 | ] 231 | }, 232 | { 233 | "cell_type": "markdown", 234 | "id": "fd243e8f", 235 | "metadata": {}, 236 | "source": [ 237 | "## Real Time Visualization\n", 238 | "\n", 239 | "As an enhancement to plot the speech probabilities in real time, I added the implementation below.\n", 240 | "In contrast to the simple one, it records the audio until you stop the recording by pressing Enter.\n", 241 | "While looking into good ways to update matplotlib plots in real-time, I found a simple library that does the job.
https://github.com/lvwerra/jupyterplot It has some limitations, but works for this use case really well.\n" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": null, 247 | "id": "d36980c2", 248 | "metadata": {}, 249 | "outputs": [], 250 | "source": [ 251 | "#!pip install jupyterplot==0.0.3" 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": null, 257 | "id": "5607b616", 258 | "metadata": {}, 259 | "outputs": [], 260 | "source": [ 261 | "from jupyterplot import ProgressPlot\n", 262 | "import threading\n", 263 | "\n", 264 | "continue_recording = True\n", 265 | "\n", 266 | "def stop():\n", 267 | " input(\"Press Enter to stop the recording:\")\n", 268 | " global continue_recording\n", 269 | " continue_recording = False\n", 270 | "\n", 271 | "def start_recording():\n", 272 | " \n", 273 | " stream = audio.open(format=FORMAT,\n", 274 | " channels=CHANNELS,\n", 275 | " rate=SAMPLE_RATE,\n", 276 | " input=True,\n", 277 | " frames_per_buffer=CHUNK)\n", 278 | "\n", 279 | " data = []\n", 280 | " voiced_confidences = []\n", 281 | " \n", 282 | " global continue_recording\n", 283 | " continue_recording = True\n", 284 | " \n", 285 | " pp = ProgressPlot(plot_names=[\"Silero VAD\"],line_names=[\"speech probabilities\"], x_label=\"audio chunks\")\n", 286 | " \n", 287 | " stop_listener = threading.Thread(target=stop)\n", 288 | " stop_listener.start()\n", 289 | "\n", 290 | " while continue_recording:\n", 291 | " \n", 292 | " audio_chunk = stream.read(num_samples)\n", 293 | " \n", 294 | " # in case you want to save the audio later\n", 295 | " data.append(audio_chunk)\n", 296 | " \n", 297 | " audio_int16 = np.frombuffer(audio_chunk, np.int16);\n", 298 | "\n", 299 | " audio_float32 = int2float(audio_int16)\n", 300 | " \n", 301 | " # get the confidences and add them to the list to plot them later\n", 302 | " new_confidence = model(torch.from_numpy(audio_float32), 16000).item()\n", 303 | " voiced_confidences.append(new_confidence)\n", 304 | " \n", 305 | " pp.update(new_confidence)\n", 306 | "\n", 307 | "\n", 308 | " pp.finalize()" 309 | ] 310 | }, 311 | { 312 | "cell_type": "code", 313 | "execution_count": null, 314 | "id": "dc4f0108", 315 | "metadata": {}, 316 | "outputs": [], 317 | "source": [ 318 | "start_recording()" 319 | ] 320 | } 321 | ], 322 | "metadata": { 323 | "kernelspec": { 324 | "display_name": "Python 3 (ipykernel)", 325 | "language": "python", 326 | "name": "python3" 327 | }, 328 | "language_info": { 329 | "codemirror_mode": { 330 | "name": "ipython", 331 | "version": 3 332 | }, 333 | "file_extension": ".py", 334 | "mimetype": "text/x-python", 335 | "name": "python", 336 | "nbconvert_exporter": "python", 337 | "pygments_lexer": "ipython3", 338 | "version": "3.10.14" 339 | }, 340 | "toc": { 341 | "base_numbering": 1, 342 | "nav_menu": {}, 343 | "number_sections": true, 344 | "sideBar": true, 345 | "skip_h1_title": false, 346 | "title_cell": "Table of Contents", 347 | "title_sidebar": "Contents", 348 | "toc_cell": false, 349 | "toc_position": {}, 350 | "toc_section_display": true, 351 | "toc_window_display": false 352 | } 353 | }, 354 | "nbformat": 4, 355 | "nbformat_minor": 5 356 | } 357 | -------------------------------------------------------------------------------- /examples/rust-example/.gitignore: -------------------------------------------------------------------------------- 1 | target/ 2 | recorder.wav -------------------------------------------------------------------------------- /examples/rust-example/Cargo.toml: 
-------------------------------------------------------------------------------- 1 | [package] 2 | name = "rust-example" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [dependencies] 7 | ort = { version = "2.0.0-rc.2", features = ["load-dynamic", "ndarray"] } 8 | ndarray = "0.15" 9 | hound = "3" 10 | -------------------------------------------------------------------------------- /examples/rust-example/README.md: -------------------------------------------------------------------------------- 1 | # Stream example in Rust 2 | Made after [C++ stream example](https://github.com/snakers4/silero-vad/tree/master/examples/cpp) 3 | 4 | ## Dependencies 5 | - To build Rust crate `ort` you need `cc` installed. 6 | 7 | ## Usage 8 | Just 9 | ``` 10 | cargo run 11 | ``` 12 | If you run example outside of this repo adjust environment variable 13 | ``` 14 | SILERO_MODEL_PATH=/path/to/silero_vad.onnx cargo run 15 | ``` 16 | If you need to test against other wav file, not `recorder.wav`, specify it as the first argument 17 | ``` 18 | cargo run -- /path/to/audio/file.wav 19 | ``` -------------------------------------------------------------------------------- /examples/rust-example/src/main.rs: -------------------------------------------------------------------------------- 1 | mod silero; 2 | mod utils; 3 | mod vad_iter; 4 | 5 | fn main() { 6 | let model_path = std::env::var("SILERO_MODEL_PATH") 7 | .unwrap_or_else(|_| String::from("../../files/silero_vad.onnx")); 8 | let audio_path = std::env::args() 9 | .nth(1) 10 | .unwrap_or_else(|| String::from("recorder.wav")); 11 | let mut wav_reader = hound::WavReader::open(audio_path).unwrap(); 12 | let sample_rate = match wav_reader.spec().sample_rate { 13 | 8000 => utils::SampleRate::EightkHz, 14 | 16000 => utils::SampleRate::SixteenkHz, 15 | _ => panic!("Unsupported sample rate. Expect 8 kHz or 16 kHz."), 16 | }; 17 | if wav_reader.spec().sample_format != hound::SampleFormat::Int { 18 | panic!("Unsupported sample format. 
Expect Int."); 19 | } 20 | let content = wav_reader 21 | .samples() 22 | .filter_map(|x| x.ok()) 23 | .collect::>(); 24 | assert!(!content.is_empty()); 25 | let silero = silero::Silero::new(sample_rate, model_path).unwrap(); 26 | let vad_params = utils::VadParams { 27 | sample_rate: sample_rate.into(), 28 | ..Default::default() 29 | }; 30 | let mut vad_iterator = vad_iter::VadIter::new(silero, vad_params); 31 | vad_iterator.process(&content).unwrap(); 32 | for timestamp in vad_iterator.speeches() { 33 | println!("{}", timestamp); 34 | } 35 | println!("Finished."); 36 | } 37 | -------------------------------------------------------------------------------- /examples/rust-example/src/silero.rs: -------------------------------------------------------------------------------- 1 | use crate::utils; 2 | use ndarray::{s, Array, Array2, ArrayBase, ArrayD, Dim, IxDynImpl, OwnedRepr}; 3 | use std::path::Path; 4 | 5 | #[derive(Debug)] 6 | pub struct Silero { 7 | session: ort::Session, 8 | sample_rate: ArrayBase, Dim<[usize; 1]>>, 9 | state: ArrayBase, Dim>, 10 | } 11 | 12 | impl Silero { 13 | pub fn new( 14 | sample_rate: utils::SampleRate, 15 | model_path: impl AsRef, 16 | ) -> Result { 17 | let session = ort::Session::builder()?.commit_from_file(model_path)?; 18 | let state = ArrayD::::zeros([2, 1, 128].as_slice()); 19 | let sample_rate = Array::from_shape_vec([1], vec![sample_rate.into()]).unwrap(); 20 | Ok(Self { 21 | session, 22 | sample_rate, 23 | state, 24 | }) 25 | } 26 | 27 | pub fn reset(&mut self) { 28 | self.state = ArrayD::::zeros([2, 1, 128].as_slice()); 29 | } 30 | 31 | pub fn calc_level(&mut self, audio_frame: &[i16]) -> Result { 32 | let data = audio_frame 33 | .iter() 34 | .map(|x| (*x as f32) / (i16::MAX as f32)) 35 | .collect::>(); 36 | let mut frame = Array2::::from_shape_vec([1, data.len()], data).unwrap(); 37 | frame = frame.slice(s![.., ..480]).to_owned(); 38 | let inps = ort::inputs![ 39 | frame, 40 | std::mem::take(&mut self.state), 41 | self.sample_rate.clone(), 42 | ]?; 43 | let res = self 44 | .session 45 | .run(ort::SessionInputs::ValueSlice::<3>(&inps))?; 46 | self.state = res["stateN"].try_extract_tensor().unwrap().to_owned(); 47 | Ok(*res["output"] 48 | .try_extract_raw_tensor::() 49 | .unwrap() 50 | .1 51 | .first() 52 | .unwrap()) 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /examples/rust-example/src/utils.rs: -------------------------------------------------------------------------------- 1 | #[derive(Debug, Clone, Copy)] 2 | pub enum SampleRate { 3 | EightkHz, 4 | SixteenkHz, 5 | } 6 | 7 | impl From for i64 { 8 | fn from(value: SampleRate) -> Self { 9 | match value { 10 | SampleRate::EightkHz => 8000, 11 | SampleRate::SixteenkHz => 16000, 12 | } 13 | } 14 | } 15 | 16 | impl From for usize { 17 | fn from(value: SampleRate) -> Self { 18 | match value { 19 | SampleRate::EightkHz => 8000, 20 | SampleRate::SixteenkHz => 16000, 21 | } 22 | } 23 | } 24 | 25 | #[derive(Debug)] 26 | pub struct VadParams { 27 | pub frame_size: usize, 28 | pub threshold: f32, 29 | pub min_silence_duration_ms: usize, 30 | pub speech_pad_ms: usize, 31 | pub min_speech_duration_ms: usize, 32 | pub max_speech_duration_s: f32, 33 | pub sample_rate: usize, 34 | } 35 | 36 | impl Default for VadParams { 37 | fn default() -> Self { 38 | Self { 39 | frame_size: 64, 40 | threshold: 0.5, 41 | min_silence_duration_ms: 0, 42 | speech_pad_ms: 64, 43 | min_speech_duration_ms: 64, 44 | max_speech_duration_s: f32::INFINITY, 45 | sample_rate: 16000, 46 | 
} 47 | } 48 | } 49 | 50 | #[derive(Debug, Default)] 51 | pub struct TimeStamp { 52 | pub start: i64, 53 | pub end: i64, 54 | } 55 | 56 | impl std::fmt::Display for TimeStamp { 57 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 58 | write!(f, "[start:{:08}, end:{:08}]", self.start, self.end) 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /examples/rust-example/src/vad_iter.rs: -------------------------------------------------------------------------------- 1 | use crate::{silero, utils}; 2 | 3 | const DEBUG_SPEECH_PROB: bool = true; 4 | #[derive(Debug)] 5 | pub struct VadIter { 6 | silero: silero::Silero, 7 | params: Params, 8 | state: State, 9 | } 10 | 11 | impl VadIter { 12 | pub fn new(silero: silero::Silero, params: utils::VadParams) -> Self { 13 | Self { 14 | silero, 15 | params: Params::from(params), 16 | state: State::new(), 17 | } 18 | } 19 | 20 | pub fn process(&mut self, samples: &[i16]) -> Result<(), ort::Error> { 21 | self.reset_states(); 22 | for audio_frame in samples.chunks_exact(self.params.frame_size_samples) { 23 | let speech_prob: f32 = self.silero.calc_level(audio_frame)?; 24 | self.state.update(&self.params, speech_prob); 25 | } 26 | self.state.check_for_last_speech(samples.len()); 27 | Ok(()) 28 | } 29 | 30 | pub fn speeches(&self) -> &[utils::TimeStamp] { 31 | &self.state.speeches 32 | } 33 | } 34 | 35 | impl VadIter { 36 | fn reset_states(&mut self) { 37 | self.silero.reset(); 38 | self.state = State::new() 39 | } 40 | } 41 | 42 | #[allow(unused)] 43 | #[derive(Debug)] 44 | struct Params { 45 | frame_size: usize, 46 | threshold: f32, 47 | min_silence_duration_ms: usize, 48 | speech_pad_ms: usize, 49 | min_speech_duration_ms: usize, 50 | max_speech_duration_s: f32, 51 | sample_rate: usize, 52 | sr_per_ms: usize, 53 | frame_size_samples: usize, 54 | min_speech_samples: usize, 55 | speech_pad_samples: usize, 56 | max_speech_samples: f32, 57 | min_silence_samples: usize, 58 | min_silence_samples_at_max_speech: usize, 59 | } 60 | 61 | impl From for Params { 62 | fn from(value: utils::VadParams) -> Self { 63 | let frame_size = value.frame_size; 64 | let threshold = value.threshold; 65 | let min_silence_duration_ms = value.min_silence_duration_ms; 66 | let speech_pad_ms = value.speech_pad_ms; 67 | let min_speech_duration_ms = value.min_speech_duration_ms; 68 | let max_speech_duration_s = value.max_speech_duration_s; 69 | let sample_rate = value.sample_rate; 70 | let sr_per_ms = sample_rate / 1000; 71 | let frame_size_samples = frame_size * sr_per_ms; 72 | let min_speech_samples = sr_per_ms * min_speech_duration_ms; 73 | let speech_pad_samples = sr_per_ms * speech_pad_ms; 74 | let max_speech_samples = sample_rate as f32 * max_speech_duration_s 75 | - frame_size_samples as f32 76 | - 2.0 * speech_pad_samples as f32; 77 | let min_silence_samples = sr_per_ms * min_silence_duration_ms; 78 | let min_silence_samples_at_max_speech = sr_per_ms * 98; 79 | Self { 80 | frame_size, 81 | threshold, 82 | min_silence_duration_ms, 83 | speech_pad_ms, 84 | min_speech_duration_ms, 85 | max_speech_duration_s, 86 | sample_rate, 87 | sr_per_ms, 88 | frame_size_samples, 89 | min_speech_samples, 90 | speech_pad_samples, 91 | max_speech_samples, 92 | min_silence_samples, 93 | min_silence_samples_at_max_speech, 94 | } 95 | } 96 | } 97 | 98 | #[derive(Debug, Default)] 99 | struct State { 100 | current_sample: usize, 101 | temp_end: usize, 102 | next_start: usize, 103 | prev_end: usize, 104 | triggered: bool, 105 | 
current_speech: utils::TimeStamp, 106 | speeches: Vec, 107 | } 108 | 109 | impl State { 110 | fn new() -> Self { 111 | Default::default() 112 | } 113 | 114 | fn update(&mut self, params: &Params, speech_prob: f32) { 115 | self.current_sample += params.frame_size_samples; 116 | if speech_prob > params.threshold { 117 | if self.temp_end != 0 { 118 | self.temp_end = 0; 119 | if self.next_start < self.prev_end { 120 | self.next_start = self 121 | .current_sample 122 | .saturating_sub(params.frame_size_samples) 123 | } 124 | } 125 | if !self.triggered { 126 | self.debug(speech_prob, params, "start"); 127 | self.triggered = true; 128 | self.current_speech.start = 129 | self.current_sample as i64 - params.frame_size_samples as i64; 130 | } 131 | return; 132 | } 133 | if self.triggered 134 | && (self.current_sample as i64 - self.current_speech.start) as f32 135 | > params.max_speech_samples 136 | { 137 | if self.prev_end > 0 { 138 | self.current_speech.end = self.prev_end as _; 139 | self.take_speech(); 140 | if self.next_start < self.prev_end { 141 | self.triggered = false 142 | } else { 143 | self.current_speech.start = self.next_start as _; 144 | } 145 | self.prev_end = 0; 146 | self.next_start = 0; 147 | self.temp_end = 0; 148 | } else { 149 | self.current_speech.end = self.current_sample as _; 150 | self.take_speech(); 151 | self.prev_end = 0; 152 | self.next_start = 0; 153 | self.temp_end = 0; 154 | self.triggered = false; 155 | } 156 | return; 157 | } 158 | if speech_prob >= (params.threshold - 0.15) && (speech_prob < params.threshold) { 159 | if self.triggered { 160 | self.debug(speech_prob, params, "speaking") 161 | } else { 162 | self.debug(speech_prob, params, "silence") 163 | } 164 | } 165 | if self.triggered && speech_prob < (params.threshold - 0.15) { 166 | self.debug(speech_prob, params, "end"); 167 | if self.temp_end == 0 { 168 | self.temp_end = self.current_sample; 169 | } 170 | if self.current_sample.saturating_sub(self.temp_end) 171 | > params.min_silence_samples_at_max_speech 172 | { 173 | self.prev_end = self.temp_end; 174 | } 175 | if self.current_sample.saturating_sub(self.temp_end) >= params.min_silence_samples { 176 | self.current_speech.end = self.temp_end as _; 177 | if self.current_speech.end - self.current_speech.start 178 | > params.min_speech_samples as _ 179 | { 180 | self.take_speech(); 181 | self.prev_end = 0; 182 | self.next_start = 0; 183 | self.temp_end = 0; 184 | self.triggered = false; 185 | } 186 | } 187 | } 188 | } 189 | 190 | fn take_speech(&mut self) { 191 | self.speeches.push(std::mem::take(&mut self.current_speech)); // current speech becomes TimeStamp::default() due to take() 192 | } 193 | 194 | fn check_for_last_speech(&mut self, last_sample: usize) { 195 | if self.current_speech.start > 0 { 196 | self.current_speech.end = last_sample as _; 197 | self.take_speech(); 198 | self.prev_end = 0; 199 | self.next_start = 0; 200 | self.temp_end = 0; 201 | self.triggered = false; 202 | } 203 | } 204 | 205 | fn debug(&self, speech_prob: f32, params: &Params, title: &str) { 206 | if DEBUG_SPEECH_PROB { 207 | let speech = self.current_sample as f32 208 | - params.frame_size_samples as f32 209 | - if title == "end" { 210 | params.speech_pad_samples 211 | } else { 212 | 0 213 | } as f32; // minus window_size_samples to get precise start time point. 
214 | println!( 215 | "[{:10}: {:.3} s ({:.3}) {:8}]", 216 | title, 217 | speech / params.sample_rate as f32, 218 | speech_prob, 219 | self.current_sample - params.frame_size_samples, 220 | ); 221 | } 222 | } 223 | } 224 | -------------------------------------------------------------------------------- /files/silero_logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snakers4/silero-vad/0dd45f0bcd7271463c234f3bae5ad25181f9df8b/files/silero_logo.jpg -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | dependencies = ['torch', 'torchaudio'] 2 | import torch 3 | import os 4 | import sys 5 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src')) 6 | from silero_vad.utils_vad import (init_jit_model, 7 | get_speech_timestamps, 8 | save_audio, 9 | read_audio, 10 | VADIterator, 11 | collect_chunks, 12 | OnnxWrapper) 13 | 14 | 15 | def versiontuple(v): 16 | splitted = v.split('+')[0].split(".") 17 | version_list = [] 18 | for i in splitted: 19 | try: 20 | version_list.append(int(i)) 21 | except: 22 | version_list.append(0) 23 | return tuple(version_list) 24 | 25 | 26 | def silero_vad(onnx=False, force_onnx_cpu=False, opset_version=16): 27 | """Silero Voice Activity Detector 28 | Returns a model with a set of utils 29 | Please see https://github.com/snakers4/silero-vad for usage examples 30 | """ 31 | available_ops = [15, 16] 32 | if onnx and opset_version not in available_ops: 33 | raise Exception(f'Available ONNX opset_version: {available_ops}') 34 | 35 | if not onnx: 36 | installed_version = torch.__version__ 37 | supported_version = '1.12.0' 38 | if versiontuple(installed_version) < versiontuple(supported_version): 39 | raise Exception(f'Please install torch {supported_version} or greater ({installed_version} installed)') 40 | 41 | model_dir = os.path.join(os.path.dirname(__file__), 'src', 'silero_vad', 'data') 42 | if onnx: 43 | if opset_version == 16: 44 | model_name = 'silero_vad.onnx' 45 | else: 46 | model_name = f'silero_vad_16k_op{opset_version}.onnx' 47 | model = OnnxWrapper(os.path.join(model_dir, model_name), force_onnx_cpu) 48 | else: 49 | model = init_jit_model(os.path.join(model_dir, 'silero_vad.jit')) 50 | utils = (get_speech_timestamps, 51 | save_audio, 52 | read_audio, 53 | VADIterator, 54 | collect_chunks) 55 | 56 | return model, utils 57 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | [project] 5 | name = "silero-vad" 6 | version = "5.1.2" 7 | authors = [ 8 | {name="Silero Team", email="hello@silero.ai"}, 9 | ] 10 | description = "Voice Activity Detector (VAD) by Silero" 11 | readme = "README.md" 12 | requires-python = ">=3.8" 13 | classifiers = [ 14 | "Development Status :: 5 - Production/Stable", 15 | "License :: OSI Approved :: MIT License", 16 | "Operating System :: OS Independent", 17 | "Intended Audience :: Science/Research", 18 | "Intended Audience :: Developers", 19 | "Programming Language :: Python :: 3.8", 20 | "Programming Language :: Python :: 3.9", 21 | "Programming Language :: Python :: 3.10", 22 | "Programming Language :: Python :: 3.11", 23 | "Programming Language :: Python :: 3.12", 24 | "Topic :: Scientific/Engineering :: Artificial 
Intelligence", 25 | "Topic :: Scientific/Engineering", 26 | ] 27 | dependencies = [ 28 | "torch>=1.12.0", 29 | "torchaudio>=0.12.0", 30 | "onnxruntime>=1.16.1", 31 | ] 32 | 33 | [project.urls] 34 | Homepage = "https://github.com/snakers4/silero-vad" 35 | Issues = "https://github.com/snakers4/silero-vad/issues" 36 | -------------------------------------------------------------------------------- /silero-vad.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "heading_collapsed": true, 7 | "id": "62A6F_072Fwq" 8 | }, 9 | "source": [ 10 | "## Install Dependencies" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": { 17 | "hidden": true, 18 | "id": "5w5AkskZ2Fwr" 19 | }, 20 | "outputs": [], 21 | "source": [ 22 | "#@title Install and Import Dependencies\n", 23 | "\n", 24 | "# this assumes that you have a relevant version of PyTorch installed\n", 25 | "!pip install -q torchaudio\n", 26 | "\n", 27 | "SAMPLING_RATE = 16000\n", 28 | "\n", 29 | "import torch\n", 30 | "torch.set_num_threads(1)\n", 31 | "\n", 32 | "from IPython.display import Audio\n", 33 | "from pprint import pprint\n", 34 | "# download example\n", 35 | "torch.hub.download_url_to_file('https://models.silero.ai/vad_models/en.wav', 'en_example.wav')" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": { 42 | "id": "pSifus5IilRp" 43 | }, 44 | "outputs": [], 45 | "source": [ 46 | "USE_PIP = True # download model using pip package or torch.hub\n", 47 | "USE_ONNX = False # change this to True if you want to test onnx model\n", 48 | "if USE_ONNX:\n", 49 | " !pip install -q onnxruntime\n", 50 | "if USE_PIP:\n", 51 | " !pip install -q silero-vad\n", 52 | " from silero_vad import (load_silero_vad,\n", 53 | " read_audio,\n", 54 | " get_speech_timestamps,\n", 55 | " save_audio,\n", 56 | " VADIterator,\n", 57 | " collect_chunks)\n", 58 | " model = load_silero_vad(onnx=USE_ONNX)\n", 59 | "else:\n", 60 | " model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n", 61 | " model='silero_vad',\n", 62 | " force_reload=True,\n", 63 | " onnx=USE_ONNX)\n", 64 | "\n", 65 | " (get_speech_timestamps,\n", 66 | " save_audio,\n", 67 | " read_audio,\n", 68 | " VADIterator,\n", 69 | " collect_chunks) = utils" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "metadata": { 75 | "id": "fXbbaUO3jsrw" 76 | }, 77 | "source": [ 78 | "## Speech timestapms from full audio" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "metadata": { 85 | "id": "aI_eydBPjsrx" 86 | }, 87 | "outputs": [], 88 | "source": [ 89 | "wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)\n", 90 | "# get speech timestamps from full audio file\n", 91 | "speech_timestamps = get_speech_timestamps(wav, model, sampling_rate=SAMPLING_RATE)\n", 92 | "pprint(speech_timestamps)" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "metadata": { 99 | "id": "OuEobLchjsry" 100 | }, 101 | "outputs": [], 102 | "source": [ 103 | "# merge all speech chunks to one audio\n", 104 | "save_audio('only_speech.wav',\n", 105 | " collect_chunks(speech_timestamps, wav), sampling_rate=SAMPLING_RATE)\n", 106 | "Audio('only_speech.wav')" 107 | ] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "metadata": { 112 | "id": "zeO1xCqxUC6w" 113 | }, 114 | "source": [ 115 | "## Entire audio inference" 116 | ] 117 | }, 118 | { 119 | 
"cell_type": "code", 120 | "execution_count": null, 121 | "metadata": { 122 | "id": "LjZBcsaTT7Mk" 123 | }, 124 | "outputs": [], 125 | "source": [ 126 | "wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)\n", 127 | "# audio is being splitted into 31.25 ms long pieces\n", 128 | "# so output length equals ceil(input_length * 31.25 / SAMPLING_RATE)\n", 129 | "predicts = model.audio_forward(wav, sr=SAMPLING_RATE)" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "metadata": { 135 | "id": "iDKQbVr8jsry" 136 | }, 137 | "source": [ 138 | "## Stream imitation example" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "metadata": { 145 | "id": "q-lql_2Wjsry" 146 | }, 147 | "outputs": [], 148 | "source": [ 149 | "## using VADIterator class\n", 150 | "\n", 151 | "vad_iterator = VADIterator(model, sampling_rate=SAMPLING_RATE)\n", 152 | "wav = read_audio(f'en_example.wav', sampling_rate=SAMPLING_RATE)\n", 153 | "\n", 154 | "window_size_samples = 512 if SAMPLING_RATE == 16000 else 256\n", 155 | "for i in range(0, len(wav), window_size_samples):\n", 156 | " chunk = wav[i: i+ window_size_samples]\n", 157 | " if len(chunk) < window_size_samples:\n", 158 | " break\n", 159 | " speech_dict = vad_iterator(chunk, return_seconds=True)\n", 160 | " if speech_dict:\n", 161 | " print(speech_dict, end=' ')\n", 162 | "vad_iterator.reset_states() # reset model states after each audio" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": null, 168 | "metadata": { 169 | "id": "BX3UgwwB2Fwv" 170 | }, 171 | "outputs": [], 172 | "source": [ 173 | "## just probabilities\n", 174 | "\n", 175 | "wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)\n", 176 | "speech_probs = []\n", 177 | "window_size_samples = 512 if SAMPLING_RATE == 16000 else 256\n", 178 | "for i in range(0, len(wav), window_size_samples):\n", 179 | " chunk = wav[i: i+ window_size_samples]\n", 180 | " if len(chunk) < window_size_samples:\n", 181 | " break\n", 182 | " speech_prob = model(chunk, SAMPLING_RATE).item()\n", 183 | " speech_probs.append(speech_prob)\n", 184 | "vad_iterator.reset_states() # reset model states after each audio\n", 185 | "\n", 186 | "print(speech_probs[:10]) # first 10 chunks predicts" 187 | ] 188 | } 189 | ], 190 | "metadata": { 191 | "colab": { 192 | "name": "silero-vad.ipynb", 193 | "provenance": [] 194 | }, 195 | "kernelspec": { 196 | "display_name": "Python 3", 197 | "language": "python", 198 | "name": "python3" 199 | }, 200 | "language_info": { 201 | "codemirror_mode": { 202 | "name": "ipython", 203 | "version": 3 204 | }, 205 | "file_extension": ".py", 206 | "mimetype": "text/x-python", 207 | "name": "python", 208 | "nbconvert_exporter": "python", 209 | "pygments_lexer": "ipython3", 210 | "version": "3.8.8" 211 | }, 212 | "toc": { 213 | "base_numbering": 1, 214 | "nav_menu": {}, 215 | "number_sections": true, 216 | "sideBar": true, 217 | "skip_h1_title": false, 218 | "title_cell": "Table of Contents", 219 | "title_sidebar": "Contents", 220 | "toc_cell": false, 221 | "toc_position": {}, 222 | "toc_section_display": true, 223 | "toc_window_display": false 224 | } 225 | }, 226 | "nbformat": 4, 227 | "nbformat_minor": 0 228 | } 229 | -------------------------------------------------------------------------------- /src/silero_vad/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib.metadata import version 2 | try: 3 | __version__ = version(__name__) 4 | except: 5 | pass 6 | 7 | 
from silero_vad.model import load_silero_vad 8 | from silero_vad.utils_vad import (get_speech_timestamps, 9 | save_audio, 10 | read_audio, 11 | VADIterator, 12 | collect_chunks) -------------------------------------------------------------------------------- /src/silero_vad/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snakers4/silero-vad/0dd45f0bcd7271463c234f3bae5ad25181f9df8b/src/silero_vad/data/__init__.py -------------------------------------------------------------------------------- /src/silero_vad/data/silero_vad.jit: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snakers4/silero-vad/0dd45f0bcd7271463c234f3bae5ad25181f9df8b/src/silero_vad/data/silero_vad.jit -------------------------------------------------------------------------------- /src/silero_vad/data/silero_vad.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snakers4/silero-vad/0dd45f0bcd7271463c234f3bae5ad25181f9df8b/src/silero_vad/data/silero_vad.onnx -------------------------------------------------------------------------------- /src/silero_vad/data/silero_vad_16k_op15.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snakers4/silero-vad/0dd45f0bcd7271463c234f3bae5ad25181f9df8b/src/silero_vad/data/silero_vad_16k_op15.onnx -------------------------------------------------------------------------------- /src/silero_vad/data/silero_vad_half.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snakers4/silero-vad/0dd45f0bcd7271463c234f3bae5ad25181f9df8b/src/silero_vad/data/silero_vad_half.onnx -------------------------------------------------------------------------------- /src/silero_vad/model.py: -------------------------------------------------------------------------------- 1 | from .utils_vad import init_jit_model, OnnxWrapper 2 | import torch 3 | torch.set_num_threads(1) 4 | 5 | 6 | def load_silero_vad(onnx=False, opset_version=16): 7 | available_ops = [15, 16] 8 | if onnx and opset_version not in available_ops: 9 | raise Exception(f'Available ONNX opset_version: {available_ops}') 10 | 11 | if onnx: 12 | if opset_version == 16: 13 | model_name = 'silero_vad.onnx' 14 | else: 15 | model_name = f'silero_vad_16k_op{opset_version}.onnx' 16 | else: 17 | model_name = 'silero_vad.jit' 18 | package_path = "silero_vad.data" 19 | 20 | try: 21 | import importlib_resources as impresources 22 | model_file_path = str(impresources.files(package_path).joinpath(model_name)) 23 | except: 24 | from importlib import resources as impresources 25 | try: 26 | with impresources.path(package_path, model_name) as f: 27 | model_file_path = f 28 | except: 29 | model_file_path = str(impresources.files(package_path).joinpath(model_name)) 30 | 31 | if onnx: 32 | model = OnnxWrapper(model_file_path, force_onnx_cpu=True) 33 | else: 34 | model = init_jit_model(model_file_path) 35 | 36 | return model 37 | -------------------------------------------------------------------------------- /tuning/README.md: -------------------------------------------------------------------------------- 1 | # Тюнинг Silero-VAD модели 2 | 3 | > Код тюнинга создан при поддержке Фонда содействия инновациям в рамках федерального проекта «Искусственный 4 | интеллект» национальной программы «Цифровая экономика 
Российской Федерации». 5 | 6 | Тюнинг используется для улучшения качества детекции речи Silero-VAD модели на кастомных данных. 7 | 8 | ## Зависимости 9 | Следующие зависимости используются при тюнинге VAD модели: 10 | - `torchaudio>=0.12.0` 11 | - `omegaconf>=2.3.0` 12 | - `sklearn>=1.2.0` 13 | - `torch>=1.12.0` 14 | - `pandas>=2.2.2` 15 | - `tqdm` 16 | 17 | ## Подготовка данных 18 | 19 | Датафреймы для тюнинга должны быть подготовлены и сохранены в формате `.feather`. Следующие колонки в `.feather` файлах тренировки и валидации являются обязательными: 20 | - **audio_path** - абсолютный путь до аудиофайла в дисковой системе. Аудиофайлы должны представлять собой `PCM` данные, предпочтительно в форматах `.wav` или `.opus` (иные популярные форматы аудио тоже поддерживаются). Для ускорения темпа дообучения рекомендуется предварительно выполнить ресемплинг аудиофайлов (изменить частоту дискретизации) до 16000 Гц; 21 | - **speech_ts** - разметка для соответствующего аудиофайла. Список, состоящий из словарей формата `{'start': START_SEC, 'end': 'END_SEC'}`, где `START_SEC` и `END_SEC` - время начало и конца речевого отрезка в секундах соответственно. Для качественного дообучения рекомендуется использовать разметку с точностью до 30 миллисекунд. 22 | 23 | Чем больше данных используется на этапе дообучения, тем эффективнее показывает себя адаптированная модель на целевом домене. Длина аудио не ограничена, т.к. каждое аудио будет обрезано до `max_train_length_sec` секунд перед подачей в нейросеть. Длинные аудио лучше предварительно порезать на кусочки длины `max_train_length_sec`. 24 | 25 | Пример `.feather` датафрейма можно посмотреть в файле `example_dataframe.feather` 26 | 27 | ## Файл конфигурации `config.yml` 28 | 29 | Файл конфигурации `config.yml` содержит пути до обучающей и валидационной выборки, а также параметры дообучения: 30 | - `train_dataset_path` - абсолютный путь до тренировочного датафрейма в формате `.feather`. Должен содержать колонки `audio_path` и `speech_ts`, описанные в пункте "Подготовка данных". Пример устройства датафрейма можно посмотреть в `example_dataframe.feather`; 31 | - `val_dataset_path` - абсолютный путь до валидационного датафрейма в формате `.feather`. Должен содержать колонки `audio_path` и `speech_ts`, описанные в пункте "Подготовка данных". Пример устройства датафрейма можно посмотреть в `example_dataframe.feather`; 32 | - `jit_model_path` - абсолютный путь до Silero-VAD модели в формате `.jit`. Если оставить это поле пустым, то модель будет загружена из репозитория в зависимости от значения поля `use_torchhub` 33 | - `use_torchhub` - Если `True`, то модель для дообучения будет загружена с помощью torch.hub. Если `False`, то модель для дообучения будет загружена с помощью библиотеки silero-vad (необходимо заранее установить командой `pip install silero-vad`); 34 | - `tune_8k` - данный параметр отвечает, какую голову Silero-VAD дообучать. Если `True`, дообучаться будет голова с 8000 Гц частотой дискретизации, иначе с 16000 Гц; 35 | - `model_save_path` - путь сохранения добученной модели; 36 | - `noise_loss` - коэффициент лосса, применяемый для неречевых окон аудио; 37 | - `max_train_length_sec` - максимальная длина аудио в секундах на этапе дообучения. 
27 | ## Configuration file `config.yml` 28 | 29 | The configuration file `config.yml` contains the paths to the train and validation sets as well as the tuning parameters: 30 | - `train_dataset_path` - absolute path to the training dataframe in the `.feather` format. It must contain the `audio_path` and `speech_ts` columns described in the "Data preparation" section; see `example_dataframe.feather` for an example of the dataframe layout; 31 | - `val_dataset_path` - absolute path to the validation dataframe in the `.feather` format. It must contain the `audio_path` and `speech_ts` columns described in the "Data preparation" section; see `example_dataframe.feather` for an example of the dataframe layout; 32 | - `jit_model_path` - absolute path to a Silero-VAD model in the `.jit` format. If this field is left empty, the model will be downloaded from the repository according to the value of the `use_torchhub` field 33 | - `use_torchhub` - if `True`, the model to be tuned is loaded via torch.hub; if `False`, it is loaded via the silero-vad library (install it beforehand with `pip install silero-vad`); 34 | - `tune_8k` - selects which Silero-VAD head is tuned. If `True`, the head with the 8000 Hz sampling rate is tuned, otherwise the 16000 Hz head; 35 | - `model_save_path` - path where the tuned model is saved; 36 | - `noise_loss` - loss coefficient applied to non-speech audio windows; 37 | - `max_train_length_sec` - maximum audio length in seconds at the tuning stage. Longer audio will be trimmed to this value; 38 | - `aug_prob` - probability of applying augmentations to an audio file during tuning; 39 | - `learning_rate` - tuning learning rate; 40 | - `batch_size` - batch size for tuning and validation; 41 | - `num_workers` - number of workers used for data loading; 42 | - `num_epochs` - number of tuning epochs. One epoch is a full pass over all training data; 43 | - `device` - `cpu` or `cuda`. 44 | 45 | ## Tuning 46 | 47 | Tuning is started with the command 48 | 49 | `python tune.py` 50 | 51 | It runs for `num_epochs` epochs; the checkpoint with the best ROC-AUC on the validation set is saved to `model_save_path` in the jit format. 52 | 53 | ## Threshold search 54 | 55 | The enter threshold and the exit threshold can be selected using the command 56 | 57 | `python search_thresholds.py` 58 | 59 | This script uses the configuration file described above. The model specified in the configuration is used to search for the optimal thresholds on the validation dataset; a sketch of applying the resulting thresholds is shown below. 60 |
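As a point of reference, this is a small illustration (not part of the repository) of how the discovered enter/exit thresholds can be applied to per-window speech probabilities; it mirrors the hysteresis logic of `calculate_best_thresholds` in `tuning/utils.py`, and the probability values are placeholders:

```python
# Hysteresis sketch: a window is labelled as speech once a probability crosses
# ths_enter and stays speech until a probability drops below ths_exit.
def apply_thresholds(probs, ths_enter=0.6, ths_exit=0.3):
    is_speech = False
    labels = []
    for p in probs:
        if p >= ths_enter:
            is_speech = True
        elif p <= ths_exit:
            is_speech = False
        labels.append(1 if is_speech else 0)
    return labels

print(apply_thresholds([0.1, 0.7, 0.5, 0.2, 0.8]))  # -> [0, 1, 1, 0, 1]
```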
61 | ## Citation 62 | 63 | ``` 64 | @misc{Silero VAD, 65 | author = {Silero Team}, 66 | title = {Silero VAD: pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier}, 67 | year = {2024}, 68 | publisher = {GitHub}, 69 | journal = {GitHub repository}, 70 | howpublished = {\url{https://github.com/snakers4/silero-vad}}, 71 | commit = {insert_some_commit_here}, 72 | email = {hello@silero.ai} 73 | } 74 | ``` 75 | -------------------------------------------------------------------------------- /tuning/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snakers4/silero-vad/0dd45f0bcd7271463c234f3bae5ad25181f9df8b/tuning/__init__.py -------------------------------------------------------------------------------- /tuning/config.yml: -------------------------------------------------------------------------------- 1 | jit_model_path: '' # path to a Silero-VAD model in the jit format; this model will be used for tuning. If the field is left empty, the model will be downloaded automatically 2 | use_torchhub: True # the jit model is loaded via torchhub if True, or via pip if False 3 | 4 | tune_8k: False # tunes the 16k head if False and the 8k head if True 5 | train_dataset_path: 'train_dataset_path.feather' # path to the feather-format dataset used for tuning, see README for details 6 | val_dataset_path: 'val_dataset_path.feather' # path to the feather-format dataset used for validation, see README for details 7 | model_save_path: 'model_save_path.jit' # path where the tuned model is saved 8 | 9 | noise_loss: 0.5 # coefficient applied to the loss on non-speech windows 10 | max_train_length_sec: 8 # during tuning, audio longer than this is trimmed to this value 11 | aug_prob: 0.4 # probability of applying augmentations to an audio file during tuning 12 | 13 | learning_rate: 5e-4 # model tuning learning rate 14 | batch_size: 128 # batch size for tuning and validation 15 | num_workers: 4 # number of workers used by the dataloaders 16 | num_epochs: 20 # number of tuning epochs, 1 epoch = a full pass over the training data 17 | device: 'cuda' # cpu or cuda, the device on which tuning will run -------------------------------------------------------------------------------- /tuning/example_dataframe.feather: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snakers4/silero-vad/0dd45f0bcd7271463c234f3bae5ad25181f9df8b/tuning/example_dataframe.feather -------------------------------------------------------------------------------- /tuning/search_thresholds.py: -------------------------------------------------------------------------------- 1 | from utils import init_jit_model, predict, calculate_best_thresholds, SileroVadDataset, SileroVadPadder 2 | from omegaconf import OmegaConf 3 | import torch 4 | torch.set_num_threads(1) 5 | 6 | if __name__ == '__main__': 7 | config = OmegaConf.load('config.yml') 8 | 9 | loader = torch.utils.data.DataLoader(SileroVadDataset(config, mode='val'), 10 | batch_size=config.batch_size, 11 | collate_fn=SileroVadPadder, 12 | num_workers=config.num_workers) 13 | 14 | if config.jit_model_path: 15 | print(f'Loading model from the local folder: {config.jit_model_path}') 16 | model = init_jit_model(config.jit_model_path, device=config.device) 17 | else: 18 | if config.use_torchhub: 19 | print('Loading model using torch.hub') 20 | model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad', 21 | model='silero_vad', 22 | onnx=False, 23 | force_reload=True) 24 | else: 25 | print('Loading model using silero-vad library') 26 | from silero_vad import load_silero_vad 27 | model = load_silero_vad(onnx=False) 28 | 29 | print('Model loaded') 30 | model.to(config.device) 31 | 32 | print('Making predictions...') 33 | all_predicts, all_gts = predict(model, loader, config.device, sr=8000 if config.tune_8k else 16000) 34 | print('Calculating thresholds...') 35 | best_ths_enter, best_ths_exit, best_acc = calculate_best_thresholds(all_predicts, all_gts) 36 | print(f'Best enter threshold: {best_ths_enter}\nBest exit threshold: {best_ths_exit}\nBest accuracy: {best_acc}') 37 | -------------------------------------------------------------------------------- /tuning/tune.py: -------------------------------------------------------------------------------- 1 | from utils import SileroVadDataset, SileroVadPadder, VADDecoderRNNJIT, train, validate, init_jit_model 2 | from omegaconf import OmegaConf 3 | import torch.nn as nn 4 | import torch 5 | 6 |
7 | if __name__ == '__main__': 8 | config = OmegaConf.load('config.yml') 9 | 10 | train_dataset = SileroVadDataset(config, mode='train') 11 | train_loader = torch.utils.data.DataLoader(train_dataset, 12 | batch_size=config.batch_size, 13 | collate_fn=SileroVadPadder, 14 | num_workers=config.num_workers) 15 | 16 | val_dataset = SileroVadDataset(config, mode='val') 17 | val_loader = torch.utils.data.DataLoader(val_dataset, 18 | batch_size=config.batch_size, 19 | collate_fn=SileroVadPadder, 20 | num_workers=config.num_workers) 21 | 22 | if config.jit_model_path: 23 | print(f'Loading model from the local folder: {config.jit_model_path}') 24 | model = init_jit_model(config.jit_model_path, device=config.device) 25 | else: 26 | if config.use_torchhub: 27 | print('Loading model using torch.hub') 28 | model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad', 29 | model='silero_vad', 30 | onnx=False, 31 | force_reload=True) 32 | else: 33 | print('Loading model using silero-vad library') 34 | from silero_vad import load_silero_vad 35 | model = load_silero_vad(onnx=False) 36 | 37 | print('Model loaded') 38 | model.to(config.device) 39 | decoder = VADDecoderRNNJIT().to(config.device) 40 | decoder.load_state_dict(model._model_8k.decoder.state_dict() if config.tune_8k else model._model.decoder.state_dict()) 41 | decoder.train() 42 | params = decoder.parameters() 43 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, params), 44 | lr=config.learning_rate) 45 | criterion = nn.BCELoss(reduction='none') 46 | 47 | best_val_roc = 0 48 | for i in range(config.num_epochs): 49 | print(f'Starting epoch {i + 1}') 50 | train_loss = train(config, train_loader, model, decoder, criterion, optimizer, config.device) 51 | val_loss, val_roc = validate(config, val_loader, model, decoder, criterion, config.device) 52 | print(f'Metrics after epoch {i + 1}:\n' 53 | f'\tTrain loss: {round(train_loss, 3)}\n', 54 | f'\tValidation loss: {round(val_loss, 3)}\n' 55 | f'\tValidation ROC-AUC: {round(val_roc, 3)}') 56 | 57 | if val_roc > best_val_roc: 58 | print('New best ROC-AUC, saving model') 59 | best_val_roc = val_roc 60 | if config.tune_8k: 61 | model._model_8k.decoder.load_state_dict(decoder.state_dict()) 62 | else: 63 | model._model.decoder.load_state_dict(decoder.state_dict()) 64 | torch.jit.save(model, config.model_save_path) 65 | print('Done') 66 | -------------------------------------------------------------------------------- /tuning/utils.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import roc_auc_score, accuracy_score 2 | from torch.utils.data import Dataset 3 | import torch.nn as nn 4 | from tqdm import tqdm 5 | import pandas as pd 6 | import numpy as np 7 | import torchaudio 8 | import warnings 9 | import random 10 | import torch 11 | import gc 12 | warnings.filterwarnings('ignore') 13 | 14 | 15 | def read_audio(path: str, 16 | sampling_rate: int = 16000, 17 | normalize=False): 18 | 19 | wav, sr = torchaudio.load(path) 20 | 21 | if wav.size(0) > 1: 22 | wav = wav.mean(dim=0, keepdim=True) 23 | 24 | if sampling_rate: 25 | if sr != sampling_rate: 26 | transform = torchaudio.transforms.Resample(orig_freq=sr, 27 | new_freq=sampling_rate) 28 | wav = transform(wav) 29 | sr = sampling_rate 30 | 31 | if normalize and wav.abs().max() != 0: 32 | wav = wav / wav.abs().max() 33 | 34 | return wav.squeeze(0) 35 | 36 | 37 | def build_audiomentations_augs(p): 38 | from audiomentations import SomeOf, AirAbsorption, BandPassFilter, BandStopFilter, 
ClippingDistortion, HighPassFilter, HighShelfFilter, \ 39 | LowPassFilter, LowShelfFilter, Mp3Compression, PeakingFilter, PitchShift, RoomSimulator, SevenBandParametricEQ, \ 40 | Aliasing, AddGaussianNoise 41 | transforms = [Aliasing(p=1), 42 | AddGaussianNoise(p=1), 43 | AirAbsorption(p=1), 44 | BandPassFilter(p=1), 45 | BandStopFilter(p=1), 46 | ClippingDistortion(p=1), 47 | HighPassFilter(p=1), 48 | HighShelfFilter(p=1), 49 | LowPassFilter(p=1), 50 | LowShelfFilter(p=1), 51 | Mp3Compression(p=1), 52 | PeakingFilter(p=1), 53 | PitchShift(p=1), 54 | RoomSimulator(p=1, leave_length_unchanged=True), 55 | SevenBandParametricEQ(p=1)] 56 | tr = SomeOf((1, 3), transforms=transforms, p=p) 57 | return tr 58 | 59 | 60 | class SileroVadDataset(Dataset): 61 | def __init__(self, 62 | config, 63 | mode='train'): 64 | 65 | self.num_samples = 512 # constant, do not change 66 | self.sr = 16000 # constant, do not change 67 | 68 | self.resample_to_8k = config.tune_8k 69 | self.noise_loss = config.noise_loss 70 | self.max_train_length_sec = config.max_train_length_sec 71 | self.max_train_length_samples = config.max_train_length_sec * self.sr 72 | 73 | assert self.max_train_length_samples % self.num_samples == 0 74 | assert mode in ['train', 'val'] 75 | 76 | dataset_path = config.train_dataset_path if mode == 'train' else config.val_dataset_path 77 | self.dataframe = pd.read_feather(dataset_path).reset_index(drop=True) 78 | self.index_dict = self.dataframe.to_dict('index') 79 | self.mode = mode 80 | print(f'DATASET SIZE : {len(self.dataframe)}') 81 | 82 | if mode == 'train': 83 | self.augs = build_audiomentations_augs(p=config.aug_prob) 84 | else: 85 | self.augs = None 86 | 87 | def __getitem__(self, idx): 88 | idx = None if self.mode == 'train' else idx # in train mode idx is ignored and a random file is sampled in load_speech_sample 89 | wav, gt, mask = self.load_speech_sample(idx) 90 | 91 | if self.mode == 'train': 92 | wav = self.add_augs(wav) 93 | if len(wav) > self.max_train_length_samples: 94 | wav = wav[:self.max_train_length_samples] 95 | gt = gt[:int(self.max_train_length_samples / self.num_samples)] 96 | mask = mask[:int(self.max_train_length_samples / self.num_samples)] 97 | 98 | wav = torch.FloatTensor(wav) 99 | if self.resample_to_8k: 100 | transform = torchaudio.transforms.Resample(orig_freq=self.sr, 101 | new_freq=8000) 102 | wav = transform(wav) 103 | return wav, torch.FloatTensor(gt), torch.from_numpy(mask) 104 | 105 | def __len__(self): 106 | return len(self.index_dict) 107 | 108 | def load_speech_sample(self, idx=None): 109 | if idx is None: 110 | idx = random.randint(0, len(self.index_dict) - 1) 111 | wav = read_audio(self.index_dict[idx]['audio_path'], self.sr).numpy() 112 | 113 | if len(wav) % self.num_samples != 0: 114 | pad_num = self.num_samples - (len(wav) % (self.num_samples)) 115 | wav = np.pad(wav, (0, pad_num), 'constant', constant_values=0) 116 | 117 | gt, mask = self.get_ground_truth_annotated(self.index_dict[idx]['speech_ts'], len(wav)) 118 | 119 | assert len(gt) == len(wav) / self.num_samples 120 | 121 | # non-speech windows are already down-weighted via the mask built in get_ground_truth_annotated 122 | 123 | return wav, gt, mask 124 | 125 | def get_ground_truth_annotated(self, annotation, audio_length_samples): 126 | gt = np.zeros(audio_length_samples) 127 | 128 | for i in annotation: 129 | gt[int(i['start'] * self.sr): int(i['end'] * self.sr)] = 1 130 | 131 | squeezed_predicts = np.average(gt.reshape(-1, self.num_samples), axis=1) 132 | squeezed_predicts = (squeezed_predicts > 0.5).astype(int) 133 | mask = np.ones(len(squeezed_predicts)) 134 | mask[squeezed_predicts == 0] = self.noise_loss 135 | return squeezed_predicts, mask 136 |
137 | def add_augs(self, wav): 138 | while True: 139 | try: 140 | wav_aug = self.augs(wav, self.sr) 141 | if np.isnan(wav_aug.max()) or np.isnan(wav_aug.min()): 142 | return wav 143 | return wav_aug 144 | except Exception: # retry with a newly sampled augmentation if this one fails 145 | continue 146 | 147 | 148 | def SileroVadPadder(batch): 149 | wavs = [batch[i][0] for i in range(len(batch))] 150 | labels = [batch[i][1] for i in range(len(batch))] 151 | masks = [batch[i][2] for i in range(len(batch))] 152 | 153 | wavs = torch.nn.utils.rnn.pad_sequence( 154 | wavs, batch_first=True, padding_value=0) 155 | 156 | labels = torch.nn.utils.rnn.pad_sequence( 157 | labels, batch_first=True, padding_value=0) 158 | 159 | masks = torch.nn.utils.rnn.pad_sequence( 160 | masks, batch_first=True, padding_value=0) 161 | 162 | return wavs, labels, masks 163 | 164 | 165 | class VADDecoderRNNJIT(nn.Module): 166 | 167 | def __init__(self): 168 | super(VADDecoderRNNJIT, self).__init__() 169 | 170 | self.rnn = nn.LSTMCell(128, 128) 171 | self.decoder = nn.Sequential(nn.Dropout(0.1), 172 | nn.ReLU(), 173 | nn.Conv1d(128, 1, kernel_size=1), 174 | nn.Sigmoid()) 175 | 176 | def forward(self, x, state=torch.zeros(0)): 177 | x = x.squeeze(-1) 178 | if len(state): 179 | h, c = self.rnn(x, (state[0], state[1])) 180 | else: 181 | h, c = self.rnn(x) 182 | 183 | x = h.unsqueeze(-1).float() 184 | state = torch.stack([h, c]) 185 | x = self.decoder(x) 186 | return x, state 187 | 188 | 189 | class AverageMeter(object): 190 | """Computes and stores the average and current value""" 191 | 192 | def __init__(self): 193 | self.reset() 194 | 195 | def reset(self): 196 | self.val = 0 197 | self.avg = 0 198 | self.sum = 0 199 | self.count = 0 200 | 201 | def update(self, val, n=1): 202 | self.val = val 203 | self.sum += val * n 204 | self.count += n 205 | self.avg = self.sum / self.count 206 | 207 | 208 | def train(config, 209 | loader, 210 | jit_model, 211 | decoder, 212 | criterion, 213 | optimizer, 214 | device): 215 | 216 | losses = AverageMeter() 217 | decoder.train() 218 | 219 | context_size = 32 if config.tune_8k else 64 220 | num_samples = 256 if config.tune_8k else 512 221 | stft_layer = jit_model._model_8k.stft if config.tune_8k else jit_model._model.stft 222 | encoder_layer = jit_model._model_8k.encoder if config.tune_8k else jit_model._model.encoder 223 | 224 | with torch.enable_grad(): 225 | for _, (x, targets, masks) in tqdm(enumerate(loader), total=len(loader)): 226 | targets = targets.to(device) 227 | x = x.to(device) 228 | masks = masks.to(device) 229 | x = torch.nn.functional.pad(x, (context_size, 0)) # left-pad the batch with the model's context window 230 | optimizer.zero_grad() # reset gradients from the previous batch before this backward pass 231 | outs = [] 232 | state = torch.zeros(0) 233 | for i in range(context_size, x.shape[1], num_samples): 234 | input_ = x[:, i-context_size:i+num_samples] 235 | out = stft_layer(input_) 236 | out = encoder_layer(out) 237 | out, state = decoder(out, state) 238 | outs.append(out) 239 | stacked = torch.cat(outs, dim=2).squeeze(1) 240 | 241 | loss = criterion(stacked, targets) 242 | loss = (loss * masks).mean() 243 | loss.backward() 244 | optimizer.step() 245 | losses.update(loss.item(), masks.numel()) 246 | 247 | torch.cuda.empty_cache() 248 | gc.collect() 249 | 250 | return losses.avg 251 | 252 | 253 | def validate(config, 254 | loader, 255 | jit_model, 256 | decoder, 257 | criterion, 258 | device): 259 | 260 | losses = AverageMeter() 261 | decoder.eval() 262 | 263 | predicts = [] 264 | gts = [] 265 | 266 | context_size = 32 if config.tune_8k else 64 267 | num_samples = 256 if config.tune_8k else 512 268 | stft_layer = jit_model._model_8k.stft if config.tune_8k else
jit_model._model.stft 269 | encoder_layer = jit_model._model_8k.encoder if config.tune_8k else jit_model._model.encoder 270 | 271 | with torch.no_grad(): 272 | for _, (x, targets, masks) in tqdm(enumerate(loader), total=len(loader)): 273 | targets = targets.to(device) 274 | x = x.to(device) 275 | masks = masks.to(device) 276 | x = torch.nn.functional.pad(x, (context_size, 0)) 277 | 278 | outs = [] 279 | state = torch.zeros(0) 280 | for i in range(context_size, x.shape[1], num_samples): 281 | input_ = x[:, i-context_size:i+num_samples] 282 | out = stft_layer(input_) 283 | out = encoder_layer(out) 284 | out, state = decoder(out, state) 285 | outs.append(out) 286 | stacked = torch.cat(outs, dim=2).squeeze(1) 287 | 288 | predicts.extend(stacked[masks != 0].tolist()) 289 | gts.extend(targets[masks != 0].tolist()) 290 | 291 | loss = criterion(stacked, targets) 292 | loss = (loss * masks).mean() 293 | losses.update(loss.item(), masks.numel()) 294 | score = roc_auc_score(gts, predicts) 295 | 296 | torch.cuda.empty_cache() 297 | gc.collect() 298 | 299 | return losses.avg, round(score, 3) 300 | 301 | 302 | def init_jit_model(model_path: str, 303 | device=torch.device('cpu')): 304 | torch.set_grad_enabled(False) 305 | model = torch.jit.load(model_path, map_location=device) 306 | model.eval() 307 | return model 308 | 309 | 310 | def predict(model, loader, device, sr): 311 | with torch.no_grad(): 312 | all_predicts = [] 313 | all_gts = [] 314 | for _, (x, targets, masks) in tqdm(enumerate(loader), total=len(loader)): 315 | x = x.to(device) 316 | out = model.audio_forward(x, sr=sr) 317 | 318 | for i, out_chunk in enumerate(out): 319 | predict = out_chunk[masks[i] != 0].cpu().tolist() 320 | gt = targets[i, masks[i] != 0].cpu().tolist() 321 | 322 | all_predicts.append(predict) 323 | all_gts.append(gt) 324 | return all_predicts, all_gts 325 | 326 | 327 | def calculate_best_thresholds(all_predicts, all_gts): 328 | best_acc = 0 329 | for ths_enter in tqdm(np.linspace(0, 1, 20)): 330 | for ths_exit in np.linspace(0, 1, 20): 331 | if ths_exit >= ths_enter: 332 | continue 333 | 334 | accs = [] 335 | for j, predict in enumerate(all_predicts): 336 | predict_bool = [] 337 | is_speech = False 338 | for i in predict: 339 | if i >= ths_enter: 340 | is_speech = True 341 | predict_bool.append(1) 342 | elif i <= ths_exit: 343 | is_speech = False 344 | predict_bool.append(0) 345 | else: 346 | val = 1 if is_speech else 0 347 | predict_bool.append(val) 348 | 349 | score = round(accuracy_score(all_gts[j], predict_bool), 4) 350 | accs.append(score) 351 | 352 | mean_acc = round(np.mean(accs), 3) 353 | if mean_acc > best_acc: 354 | best_acc = mean_acc 355 | best_ths_enter = round(ths_enter, 2) 356 | best_ths_exit = round(ths_exit, 2) 357 | return best_ths_enter, best_ths_exit, best_acc 358 | --------------------------------------------------------------------------------
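Once tuning has finished, the checkpoint written to `model_save_path` is a regular TorchScript model, so it can be used with the helpers shipped in the silero-vad package. Below is a minimal streaming sketch (not part of the repository); `model_save_path.jit` and `example.wav` are placeholder paths, and it assumes the 16 kHz head was tuned (512-sample windows):

```python
# Sketch: stream a 16 kHz file through the tuned checkpoint with VADIterator.
# Paths are placeholders; `pip install silero-vad` is assumed.
import torch
from silero_vad import read_audio, VADIterator

model = torch.jit.load('model_save_path.jit', map_location='cpu')
model.eval()

vad_iterator = VADIterator(model, sampling_rate=16000)
wav = read_audio('example.wav', sampling_rate=16000)

window_size_samples = 512  # window size used by the 16 kHz head
for i in range(0, len(wav), window_size_samples):
    chunk = wav[i: i + window_size_samples]
    if len(chunk) < window_size_samples:
        break
    speech_dict = vad_iterator(chunk, return_seconds=True)
    if speech_dict:
        print(speech_dict)  # {'start': ...} or {'end': ...}
vad_iterator.reset_states()
```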