├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.md
│   │   ├── feature_request.md
│   │   └── questions---help---support.md
│   └── workflows
│       └── python-publish.yml
├── CITATION.cff
├── CODE_OF_CONDUCT.md
├── LICENSE
├── README.md
├── datasets
│   └── README.md
├── examples
│   ├── colab_record_example.ipynb
│   ├── cpp
│   │   ├── README.md
│   │   ├── silero-vad-onnx.cpp
│   │   └── wav.h
│   ├── cpp_libtorch
│   │   ├── README.md
│   │   ├── aepyx.wav
│   │   ├── main.cc
│   │   ├── silero
│   │   ├── silero_torch.cc
│   │   ├── silero_torch.h
│   │   └── wav.h
│   ├── csharp
│   │   ├── Program.cs
│   │   ├── SileroSpeechSegment.cs
│   │   ├── SileroVadDetector.cs
│   │   ├── SileroVadOnnxModel.cs
│   │   ├── VadDotNet.csproj
│   │   └── resources
│   │       └── put_model_here.txt
│   ├── go
│   │   ├── README.md
│   │   ├── cmd
│   │   │   └── main.go
│   │   ├── go.mod
│   │   └── go.sum
│   ├── haskell
│   │   ├── README.md
│   │   ├── app
│   │   │   └── Main.hs
│   │   ├── example.cabal
│   │   ├── package.yaml
│   │   ├── stack.yaml
│   │   └── stack.yaml.lock
│   ├── java-example
│   │   ├── pom.xml
│   │   └── src
│   │       └── main
│   │           └── java
│   │               └── org
│   │                   └── example
│   │                       ├── App.java
│   │                       ├── SlieroVadDetector.java
│   │                       └── SlieroVadOnnxModel.java
│   ├── java-wav-file-example
│   │   └── src
│   │       └── main
│   │           └── java
│   │               └── org
│   │                   └── example
│   │                       ├── App.java
│   │                       ├── SileroSpeechSegment.java
│   │                       ├── SileroVadDetector.java
│   │                       └── SileroVadOnnxModel.java
│   ├── microphone_and_webRTC_integration
│   │   ├── README.md
│   │   └── microphone_and_webRTC_integration.py
│   ├── parallel_example.ipynb
│   ├── pyaudio-streaming
│   │   ├── README.md
│   │   └── pyaudio-streaming-examples.ipynb
│   └── rust-example
│       ├── .gitignore
│       ├── Cargo.lock
│       ├── Cargo.toml
│       ├── README.md
│       └── src
│           ├── main.rs
│           ├── silero.rs
│           ├── utils.rs
│           └── vad_iter.rs
├── files
│   └── silero_logo.jpg
├── hubconf.py
├── pyproject.toml
├── silero-vad.ipynb
├── src
│   └── silero_vad
│       ├── __init__.py
│       ├── data
│       │   ├── __init__.py
│       │   ├── silero_vad.jit
│       │   ├── silero_vad.onnx
│       │   ├── silero_vad_16k_op15.onnx
│       │   └── silero_vad_half.onnx
│       ├── model.py
│       └── utils_vad.py
└── tuning
    ├── README.md
    ├── __init__.py
    ├── config.yml
    ├── example_dataframe.feather
    ├── search_thresholds.py
    ├── tune.py
    └── utils.py
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: Bug report - [X]
5 | labels: bug
6 | assignees: snakers4
7 |
8 | ---
9 |
10 | ## 🐛 Bug
11 |
12 |
13 |
14 | ## To Reproduce
15 |
16 | Steps to reproduce the behavior:
17 |
18 | 1.
19 | 2.
20 | 3.
21 |
22 |
23 |
24 | ## Expected behavior
25 |
26 |
27 |
28 | ## Environment
29 |
30 | Please copy and paste the output from this
31 | [environment collection script](https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py)
32 | (or fill out the checklist below manually).
33 |
34 | You can get the script and run it with:
35 | ```
36 | wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py
37 | # For security purposes, please check the contents of collect_env.py before running it.
38 | python collect_env.py
39 | ```
40 |
41 | - PyTorch Version (e.g., 1.0):
42 | - OS (e.g., Linux):
43 | - How you installed PyTorch (`conda`, `pip`, source):
44 | - Build command you used (if compiling from source):
45 | - Python version:
46 | - CUDA/cuDNN version:
47 | - GPU models and configuration:
48 | - Any other relevant information:
49 |
50 | ## Additional context
51 |
52 |
53 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: Feature request - [X]
5 | labels: enhancement
6 | assignees: snakers4
7 |
8 | ---
9 |
10 | ## 🚀 Feature
11 |
12 |
13 | ## Motivation
14 |
15 |
16 |
17 | ## Pitch
18 |
19 |
20 |
21 | ## Alternatives
22 |
23 |
24 |
25 | ## Additional context
26 |
27 |
28 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/questions---help---support.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Questions / Help / Support
2 | about: Ask for help or support, or ask a question
4 | title: "❓ Questions / Help / Support"
5 | labels: help wanted
6 | assignees: snakers4
7 |
8 | ---
9 |
10 | ## ❓ Questions and Help
11 |
12 | We have a [wiki](https://github.com/snakers4/silero-models/wiki) available for our users. Please make sure you have checked it out first.
13 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | push:
13 | tags:
14 | - '*'
15 |
16 | permissions:
17 | contents: read
18 |
19 | jobs:
20 | deploy:
21 |
22 | runs-on: ubuntu-latest
23 |
24 | steps:
25 | - uses: actions/checkout@v4
26 | - name: Set up Python
27 | uses: actions/setup-python@v3
28 | with:
29 | python-version: '3.x'
30 | - name: Install dependencies
31 | run: |
32 | python -m pip install --upgrade pip
33 | pip install build
34 | - name: Build package
35 | run: python -m build
36 | - name: Publish package
37 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
38 | with:
39 | user: __token__
40 | password: ${{ secrets.PYPI_API_TOKEN }}
41 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | message: "If you use this software, please cite it as below."
3 | title: "Silero VAD"
4 | authors:
5 | - family-names: "Silero Team"
6 | email: "hello@silero.ai"
7 | type: software
8 | repository-code: "https://github.com/snakers4/silero-vad"
9 | license: MIT
10 | abstract: "Pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier"
11 | preferred-citation:
12 | type: software
13 | authors:
14 | - family-names: "Silero Team"
15 | email: "hello@silero.ai"
16 | title: "Silero VAD: pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier"
17 | year: 2024
18 | publisher: "GitHub"
19 | journal: "GitHub repository"
20 | howpublished: "https://github.com/snakers4/silero-vad"
21 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Contributor Covenant Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to making participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies both within project spaces and in public spaces
49 | when an individual is representing the project or its community. Examples of
50 | representing a project or community include using an official project e-mail
51 | address, posting via an official social media account, or acting as an appointed
52 | representative at an online or offline event. Representation of a project may be
53 | further defined and clarified by project maintainers.
54 |
55 | ## Enforcement
56 |
57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
58 | reported by contacting the project team at aveysov@gmail.com. All
59 | complaints will be reviewed and investigated and will result in a response that
60 | is deemed necessary and appropriate to the circumstances. The project team is
61 | obligated to maintain confidentiality with regard to the reporter of an incident.
62 | Further details of specific enforcement policies may be posted separately.
63 |
64 | Project maintainers who do not follow or enforce the Code of Conduct in good
65 | faith may face temporary or permanent repercussions as determined by other
66 | members of the project's leadership.
67 |
68 | ## Attribution
69 |
70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
72 |
73 | [homepage]: https://www.contributor-covenant.org
74 |
75 | For answers to common questions about this code of conduct, see
76 | https://www.contributor-covenant.org/faq
77 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020-present Silero Team
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [](mailto:hello@silero.ai) [](https://t.me/silero_speech) [](https://github.com/snakers4/silero-vad/blob/master/LICENSE) [](https://pypi.org/project/silero-vad/)
2 |
3 | [](https://colab.research.google.com/github/snakers4/silero-vad/blob/master/silero-vad.ipynb)
4 |
5 | 
6 |
7 |
8 | # Silero VAD
9 |
10 |
11 | **Silero VAD** - pre-trained enterprise-grade [Voice Activity Detector](https://en.wikipedia.org/wiki/Voice_activity_detection) (also see our [STT models](https://github.com/snakers4/silero-models)).
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 | ## Real Time Example
22 |
23 | https://user-images.githubusercontent.com/36505480/144874384-95f80f6d-a4f1-42cc-9be7-004c891dd481.mp4
24 |
25 | Please note that the video loads only if you are logged in to your GitHub account.
26 |
27 |
28 |
29 |
30 |
31 | ## Fast start
32 |
33 |
34 |
35 | ### Dependencies
36 |
37 | System requirements to run the Python examples on `x86-64` systems:
38 |
39 | - `python 3.8+`;
40 | - 1G+ RAM;
41 | - A modern CPU with AVX, AVX2, AVX-512 or AMX instruction sets.
42 |
43 | Dependencies:
44 |
45 | - `torch>=1.12.0`;
46 | - `torchaudio>=0.12.0` (for I/O only);
47 | - `onnxruntime>=1.16.1` (for ONNX model usage).
48 |
49 | Silero VAD uses the torchaudio library for audio I/O (`torchaudio.info`, `torchaudio.load`, and `torchaudio.save`), so a proper audio backend is required:
50 |
51 | - Option №1 - [**FFmpeg**](https://www.ffmpeg.org/) backend. `conda install -c conda-forge 'ffmpeg<7'`;
52 | - Option №2 - [**sox_io**](https://pypi.org/project/sox/) backend. `apt-get install sox`, TorchAudio is tested on libsox 14.4.2;
53 | - Option №3 - [**soundfile**](https://pypi.org/project/soundfile/) backend. `pip install soundfile`.
54 |
55 | If you are planning to run the VAD using solely `onnxruntime`, it will run on any other system architecture where onnxruntime is [supported](https://onnxruntime.ai/getting-started) (see the sketch after the list below). In this case please note that:
56 |
57 | - You will have to implement the I/O;
58 | - You will have to adapt the existing wrappers / examples / post-processing for your use-case.
59 |
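As a starting point for such a setup, here is a minimal sketch of running the ONNX model directly with `onnxruntime`. It is hedged: it assumes the v5 graph's input names (`input`, `state`, `sr`), 512-sample chunks of 16 kHz audio with a 64-sample context carried over from the previous chunk, and a placeholder model path; audio loading, chunking and post-processing are left to you.

```python3
import numpy as np
import onnxruntime as ort

# Assumptions (hedged): v5 ONNX graph with inputs 'input', 'state', 'sr';
# 16 kHz audio fed as 512-sample chunks, each prepended with 64 samples of
# context from the previous chunk; the model path is a placeholder.
session = ort.InferenceSession('silero_vad.onnx')
state = np.zeros((2, 1, 128), dtype=np.float32)   # recurrent state, reset per audio stream
context = np.zeros((1, 64), dtype=np.float32)     # trailing samples of the previous chunk
sr = np.array(16000, dtype=np.int64)

def speech_prob(chunk: np.ndarray) -> float:
    """chunk: float32 array with exactly 512 samples of 16 kHz audio in [-1, 1]."""
    global state, context
    x = np.concatenate([context, chunk[np.newaxis, :]], axis=1)
    out, state = session.run(None, {'input': x, 'state': state, 'sr': sr})
    context = x[:, -64:]
    return float(out[0][0])
```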
60 |
61 |
62 | **Using pip**:
63 | `pip install silero-vad`
64 |
65 | ```python3
66 | from silero_vad import load_silero_vad, read_audio, get_speech_timestamps
67 | model = load_silero_vad()
68 | wav = read_audio('path_to_audio_file')
69 | speech_timestamps = get_speech_timestamps(
70 | wav,
71 | model,
72 | return_seconds=True, # Return speech timestamps in seconds (default is samples)
73 | )
74 | ```
75 |
76 | **Using torch.hub**:
77 | ```python3
78 | import torch
79 | torch.set_num_threads(1)
80 |
81 | model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
82 | (get_speech_timestamps, _, read_audio, _, _) = utils
83 |
84 | wav = read_audio('path_to_audio_file')
85 | speech_timestamps = get_speech_timestamps(
86 | wav,
87 | model,
88 | return_seconds=True, # Return speech timestamps in seconds (default is samples)
89 | )
90 | ```
91 |
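For streaming, chunk-by-chunk usage there is also a `VADIterator` helper (exported by the pip package and returned in the torch.hub `utils` tuple). The sketch below is hedged: the audio path is a placeholder and 512-sample chunks are assumed for 16 kHz audio.

```python3
from silero_vad import load_silero_vad, read_audio, VADIterator

model = load_silero_vad()
vad_iterator = VADIterator(model, sampling_rate=16000)

wav = read_audio('path_to_audio_file')  # placeholder path
window_size_samples = 512               # 512-sample chunks for 16 kHz audio

for i in range(0, len(wav), window_size_samples):
    chunk = wav[i: i + window_size_samples]
    if len(chunk) < window_size_samples:
        break
    speech_dict = vad_iterator(chunk, return_seconds=True)
    if speech_dict:  # a dict is returned only when a speech start or end is detected
        print(speech_dict)

vad_iterator.reset_states()  # reset model states after each audio stream
```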
92 |
93 |
94 | ## Key Features
95 |
96 |
97 | - **Stellar accuracy**
98 |
99 | Silero VAD has [excellent results](https://github.com/snakers4/silero-vad/wiki/Quality-Metrics#vs-other-available-solutions) on speech detection tasks.
100 |
101 | - **Fast**
102 |
103 | One audio chunk (30+ ms) [takes](https://github.com/snakers4/silero-vad/wiki/Performance-Metrics#silero-vad-performance-metrics) less than **1ms** to be processed on a single CPU thread. Using batching or GPU can also improve performance considerably. Under certain conditions ONNX may even run up to 4-5x faster.
104 |
105 | - **Lightweight**
106 |
107 | The JIT model is around two megabytes in size.
108 |
109 | - **General**
110 |
111 | Silero VAD was trained on huge corpora that include over **6000** languages and it performs well on audio from different domains with various background noise and quality levels.
112 |
113 | - **Flexible sampling rate**
114 |
115 | Silero VAD [supports](https://github.com/snakers4/silero-vad/wiki/Quality-Metrics#sample-rate-comparison) **8000 Hz** and **16000 Hz** [sampling rates](https://en.wikipedia.org/wiki/Sampling_(signal_processing)#Sampling_rate).
116 |
117 | - **Highly Portable**
118 |
119 | Silero VAD benefits from the rich ecosystems built around **PyTorch** and **ONNX** and runs everywhere these runtimes are available.
120 |
121 | - **No Strings Attached**
122 |
123 | Published under a permissive license (MIT), Silero VAD has zero strings attached: no telemetry, no registration, no built-in expiration, no keys, and no vendor lock-in.
124 |
125 |
126 |
127 | ## Typical Use Cases
128 |
129 |
130 | - Voice activity detection for IoT / edge / mobile use cases
131 | - Data cleaning and preparation, voice detection in general
132 | - Telephony and call-center automation, voice bots
133 | - Voice interfaces
134 |
135 |
136 | ## Links
137 |
138 |
139 |
140 | - [Examples and Dependencies](https://github.com/snakers4/silero-vad/wiki/Examples-and-Dependencies#dependencies)
141 | - [Quality Metrics](https://github.com/snakers4/silero-vad/wiki/Quality-Metrics)
142 | - [Performance Metrics](https://github.com/snakers4/silero-vad/wiki/Performance-Metrics)
143 | - [Versions and Available Models](https://github.com/snakers4/silero-vad/wiki/Version-history-and-Available-Models)
144 | - [Further reading](https://github.com/snakers4/silero-models#further-reading)
145 | - [FAQ](https://github.com/snakers4/silero-vad/wiki/FAQ)
146 |
147 |
148 | ## Get In Touch
149 |
150 |
151 | Try our models, create an [issue](https://github.com/snakers4/silero-vad/issues/new), start a [discussion](https://github.com/snakers4/silero-vad/discussions/new), join our telegram [chat](https://t.me/silero_speech), [email](mailto:hello@silero.ai) us, read our [news](https://t.me/silero_news).
152 |
153 | Please see our [wiki](https://github.com/snakers4/silero-models/wiki) for relevant information, or [email](mailto:hello@silero.ai) us directly.
154 |
155 | **Citations**
156 |
157 | ```
158 | @misc{Silero_VAD,
159 | author = {Silero Team},
160 | title = {Silero VAD: pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier},
161 | year = {2024},
162 | publisher = {GitHub},
163 | journal = {GitHub repository},
164 | howpublished = {\url{https://github.com/snakers4/silero-vad}},
165 | commit = {insert_some_commit_here},
166 | email = {hello@silero.ai}
167 | }
168 | ```
169 |
170 |
171 | ## Examples and VAD-based Community Apps
172 |
173 |
174 | - Example of VAD ONNX Runtime model usage in [C++](https://github.com/snakers4/silero-vad/tree/master/examples/cpp)
175 |
176 | - Voice activity detection for the [browser](https://github.com/ricky0123/vad) using ONNX Runtime Web
177 |
178 | - [Rust](https://github.com/snakers4/silero-vad/tree/master/examples/rust-example), [Go](https://github.com/snakers4/silero-vad/tree/master/examples/go), [Java](https://github.com/snakers4/silero-vad/tree/master/examples/java-example), [C++](https://github.com/snakers4/silero-vad/tree/master/examples/cpp), [C#](https://github.com/snakers4/silero-vad/tree/master/examples/csharp) and [other](https://github.com/snakers4/silero-vad/tree/master/examples) community examples
179 |
--------------------------------------------------------------------------------
/datasets/README.md:
--------------------------------------------------------------------------------
1 | # Silero-VAD Dataset
2 |
3 | > The dataset was created with the support of the Innovation Assistance Foundation (Фонд содействия инновациям) within the federal project "Artificial
4 | Intelligence" of the national program "Digital Economy of the Russian Federation".
5 |
6 | The links below point to `.feather` files containing open audio datasets annotated with Silero VAD, together with a short description of each dataset and download examples. The `.feather` files can be opened with the `pandas` library:
7 | ```python3
8 | import pandas as pd
9 | dataframe = pd.read_feather(PATH_TO_FEATHER_FILE)
10 | ```
11 |
12 | Each annotated `.feather` file contains the following columns:
13 | - `speech_timings` - the annotation for the given audio. It is a list of dictionaries of the form `{'start': START_SECOND, 'end': END_SECOND}`, where `START_SECOND` and `END_SECOND` are the start and end times of speech in seconds. The number of such dictionaries equals the number of speech fragments found in the audio;
14 | - `language` - the ISO code of the audio's language.
15 |
16 | The columns with information on how to download the audio files differ between datasets and are described for each dataset below.
17 |
18 | **All data were annotated with a time step of ~30 milliseconds (`num_samples` - 512).**
19 |
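As an illustration, here is a short sketch of unpacking the `speech_timings` column into one row per speech segment (the feather path is a placeholder; the column names are as described above):

```python3
import pandas as pd

dataframe = pd.read_feather('BibleIs.feather')  # placeholder: any of the files listed below

# Flatten the per-audio annotation into one row per speech segment
rows = []
for _, row in dataframe.iterrows():
    for segment in row['speech_timings']:
        rows.append({
            'language': row['language'],
            'start_second': segment['start'],
            'end_second': segment['end'],
        })
segments = pd.DataFrame(rows)
print(segments.head())
```
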
20 | | Name | Hours | Languages | Link | License | md5sum |
21 | |----------------------|-------------|-------------|--------|----------|----------|
22 | | **Bible.is** | 53,138 | 1,596 | [URL](https://live.bible.is/) | [Custom](https://live.bible.is/terms) | ea404eeaf2cd283b8223f63002be11f9 |
23 | | **globalrecordings.net** | 9,743 | 6,171[^1] | [URL](https://globalrecordings.net/en) | CC BY-NC-SA 4.0 | 3c5c0f31b0abd9fe94ddbe8b1e2eb326 |
24 | | **VoxLingua107** | 6,628 | 107 | [URL](https://bark.phon.ioc.ee/voxlingua107/) | CC BY 4.0 | 5dfef33b4d091b6d399cfaf3d05f2140 |
25 | | **Common Voice** | 30,329 | 120 | [URL](https://commonvoice.mozilla.org/en/datasets) | CC0 | 5e30a85126adf74a5fd1496e6ac8695d |
26 | | **MLS** | 50,709 | 8 | [URL](https://www.openslr.org/94/) | CC BY 4.0 | a339d0e94bdf41bba3c003756254ac4e |
27 | | **Total** | **150,547** | **6,171+** | | | |
28 |
29 | ## Bible.is
30 |
31 | [Link to the `.feather` file with annotations](https://models.silero.ai/vad_datasets/BibleIs.feather)
32 |
33 | - The `audio_link` column contains links to the individual audio files.
34 |
35 | ## globalrecordings.net
36 |
37 | [Link to the `.feather` file with annotations](https://models.silero.ai/vad_datasets/globalrecordings.feather)
38 |
39 | - The `folder_link` column contains links for downloading the `.zip` archive for a given language. Note: the archive links are repeated, since each archive may contain many audio files.
40 | - The `audio_path` column contains paths to the individual audio files after unpacking the corresponding archive from the `folder_link` column.
41 |
42 | ``The number of unique ISO codes in this dataset does not match the actual number of languages, since some closely related languages may share the same ISO code.``
43 |
44 | ## VoxLingua107
45 |
46 | [Link to the `.feather` file with annotations](https://models.silero.ai/vad_datasets/VoxLingua107.feather)
47 |
48 | - The `folder_link` column contains links for downloading the `.zip` archive for a given language. Note: the archive links are repeated, since each archive may contain many audio files.
49 | - The `audio_path` column contains paths to the individual audio files after unpacking the corresponding archive from the `folder_link` column.
50 |
51 | ## Common Voice
52 |
53 | [Link to the `.feather` file with annotations](https://models.silero.ai/vad_datasets/common_voice.feather)
54 |
55 | This dataset cannot be downloaded via static links. To obtain it, follow the [link](https://commonvoice.mozilla.org/en/datasets), request access via the corresponding form, and download the archives for each available language. Note: the provided annotations correspond to the `Common Voice Corpus 16.1` version of the source dataset.
56 |
57 | - The `audio_path` column contains the unique names of the `.mp3` files obtained after downloading the corresponding dataset.
58 |
59 | ## MLS
60 |
61 | [Link to the `.feather` file with annotations](https://models.silero.ai/vad_datasets/MLS.feather)
62 |
63 | - The `folder_link` column contains links for downloading the `.zip` archive for a given language. Note: the archive links are repeated, since each archive may contain many audio files.
64 | - The `audio_path` column contains paths to the individual audio files after unpacking the corresponding archive from the `folder_link` column.
65 |
66 | ## License
67 |
68 | This dataset is distributed under the [CC BY-NC-SA 4.0 license](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en).
69 |
70 | ## Citation
71 |
72 | ```
73 | @misc{Silero_VAD_Dataset,
74 | author = {Silero Team},
75 | title = {Silero-VAD Dataset: a large public Internet-scale dataset for voice activity detection for 6000+ languages},
76 | year = {2024},
77 | publisher = {GitHub},
78 | journal = {GitHub repository},
79 | howpublished = {\url{https://github.com/snakers4/silero-vad/datasets/README.md}},
80 | email = {hello@silero.ai}
81 | }
82 | ```
83 |
84 | [^1]: ``The number of unique ISO codes in this dataset does not match the actual number of languages, since some closely related languages may share the same ISO code.``
85 |
--------------------------------------------------------------------------------
/examples/colab_record_example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "bccAucKjnPHm"
7 | },
8 | "source": [
9 | "### Dependencies and inputs"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {
16 | "id": "cSih95WFmwgi"
17 | },
18 | "outputs": [],
19 | "source": [
20 | "#!apt install ffmpeg\n",
21 | "!pip -q install pydub\n",
22 | "from google.colab import output\n",
23 | "from base64 import b64decode, b64encode\n",
24 | "from io import BytesIO\n",
25 | "import numpy as np\n",
26 | "from pydub import AudioSegment\n",
27 | "from IPython.display import HTML, display\n",
28 | "import torch\n",
29 | "import matplotlib.pyplot as plt\n",
30 | "import moviepy.editor as mpe\n",
31 | "from matplotlib.animation import FuncAnimation, FFMpegWriter\n",
32 | "import matplotlib\n",
33 | "matplotlib.use('Agg')\n",
34 | "\n",
35 | "torch.set_num_threads(1)\n",
36 | "\n",
37 | "model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
38 | " model='silero_vad',\n",
39 | " force_reload=True)\n",
40 | "\n",
41 | "def int2float(audio):\n",
42 | " samples = audio.get_array_of_samples()\n",
43 | " new_sound = audio._spawn(samples)\n",
44 | " arr = np.array(samples).astype(np.float32)\n",
45 | " arr = arr / np.abs(arr).max()\n",
46 | " return arr\n",
47 | "\n",
48 | "AUDIO_HTML = \"\"\"\n",
49 | "\n",
126 | "\"\"\"\n",
127 | "\n",
128 | "def record(sec=10):\n",
129 | " display(HTML(AUDIO_HTML))\n",
130 | " s = output.eval_js(\"data\")\n",
131 | " b = b64decode(s.split(',')[1])\n",
132 | " audio = AudioSegment.from_file(BytesIO(b))\n",
133 | " audio.export('test.mp3', format='mp3')\n",
134 | " audio = audio.set_channels(1)\n",
135 | " audio = audio.set_frame_rate(16000)\n",
136 | " audio_float = int2float(audio)\n",
137 | " audio_tens = torch.tensor(audio_float)\n",
138 | " return audio_tens\n",
139 | "\n",
140 | "def make_animation(probs, audio_duration, interval=40):\n",
141 | " fig = plt.figure(figsize=(16, 9))\n",
142 | " ax = plt.axes(xlim=(0, audio_duration), ylim=(0, 1.02))\n",
143 | " line, = ax.plot([], [], lw=2)\n",
144 | " x = [i / 16000 * 512 for i in range(len(probs))]\n",
145 | " plt.xlabel('Time, seconds', fontsize=16)\n",
146 | " plt.ylabel('Speech Probability', fontsize=16)\n",
147 | "\n",
148 | " def init():\n",
149 | " plt.fill_between(x, probs, color='#064273')\n",
150 | " line.set_data([], [])\n",
151 | " line.set_color('#990000')\n",
152 | " return line,\n",
153 | "\n",
154 | " def animate(i):\n",
155 | " x = i * interval / 1000 - 0.04\n",
156 | " y = np.linspace(0, 1.02, 2)\n",
157 | "\n",
158 | " line.set_data(x, y)\n",
159 | " line.set_color('#990000')\n",
160 | " return line,\n",
161 | " anim = FuncAnimation(fig, animate, init_func=init, interval=interval, save_count=int(audio_duration / (interval / 1000)))\n",
162 | "\n",
163 | " f = r\"animation.mp4\"\n",
164 | " writervideo = FFMpegWriter(fps=1000/interval)\n",
165 | " anim.save(f, writer=writervideo)\n",
166 | " plt.close('all')\n",
167 | "\n",
168 | "def combine_audio(vidname, audname, outname, fps=25):\n",
169 | " my_clip = mpe.VideoFileClip(vidname, verbose=False)\n",
170 | " audio_background = mpe.AudioFileClip(audname)\n",
171 | " final_clip = my_clip.set_audio(audio_background)\n",
172 | " final_clip.write_videofile(outname,fps=fps,verbose=False)\n",
173 | "\n",
174 | "def record_make_animation():\n",
175 | " tensor = record()\n",
176 | " print('Calculating probabilities...')\n",
177 | " speech_probs = []\n",
178 | " window_size_samples = 512\n",
179 | " speech_probs = model.audio_forward(tensor, sr=16000)[0].tolist()\n",
180 | " model.reset_states()\n",
181 | " print('Making animation...')\n",
182 | " make_animation(speech_probs, len(tensor) / 16000)\n",
183 | "\n",
184 | " print('Merging your voice with animation...')\n",
185 | " combine_audio('animation.mp4', 'test.mp3', 'merged.mp4')\n",
186 | " print('Done!')\n",
187 | " mp4 = open('merged.mp4','rb').read()\n",
188 | " data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
189 | " display(HTML(\"\"\"\n",
190 | " \n",
191 | " \n",
192 | " \n",
193 | " \"\"\" % data_url))\n",
194 | "\n",
195 | " return speech_probs"
196 | ]
197 | },
198 | {
199 | "cell_type": "markdown",
200 | "metadata": {
201 | "id": "IFVs3GvTnpB1"
202 | },
203 | "source": [
204 | "## Record example"
205 | ]
206 | },
207 | {
208 | "cell_type": "code",
209 | "execution_count": null,
210 | "metadata": {
211 | "id": "5EBjrTwiqAaQ"
212 | },
213 | "outputs": [],
214 | "source": [
215 | "speech_probs = record_make_animation()"
216 | ]
217 | }
218 | ],
219 | "metadata": {
220 | "colab": {
221 | "collapsed_sections": [
222 | "bccAucKjnPHm"
223 | ],
224 | "name": "Untitled2.ipynb",
225 | "provenance": []
226 | },
227 | "kernelspec": {
228 | "display_name": "Python 3",
229 | "name": "python3"
230 | },
231 | "language_info": {
232 | "name": "python"
233 | }
234 | },
235 | "nbformat": 4,
236 | "nbformat_minor": 0
237 | }
238 |
--------------------------------------------------------------------------------
/examples/cpp/README.md:
--------------------------------------------------------------------------------
1 | # Stream example in C++
2 |
3 | Here's a simple example of using the VAD model in C++ with ONNX Runtime.
4 |
5 |
6 |
7 | ## Requirements
8 |
9 | The code was tested in the environments below; feel free to try others.
10 |
11 | - WSL2 + Debian-bullseye (docker)
12 | - gcc 12.2.0
13 | - onnxruntime-linux-x64-1.12.1
14 |
15 |
16 |
17 | ## Usage
18 |
19 | 1. Install gcc 12.2.0, or just pull the docker image with `docker pull gcc:12.2.0-bullseye`
20 |
21 | 2. Install onnxruntime-linux-x64-1.12.1
22 |
23 | - Download lib onnxruntime:
24 |
25 | `wget https://github.com/microsoft/onnxruntime/releases/download/v1.12.1/onnxruntime-linux-x64-1.12.1.tgz`
26 |
27 | - Unzip. Assume the path is `/root/onnxruntime-linux-x64-1.12.1`
28 |
29 | 3. Modify the WAV path and test configs in the main function
30 |
31 | `wav::WavReader wav_reader("${path_to_your_wav_file}");`
32 |
33 | e.g. sample rate, frame length in ms, threshold, ...
34 |
35 | 4. Build with gcc and run
36 |
37 | ```bash
38 | # Build
39 | g++ silero-vad-onnx.cpp -I /root/onnxruntime-linux-x64-1.12.1/include/ -L /root/onnxruntime-linux-x64-1.12.1/lib/ -lonnxruntime -Wl,-rpath,/root/onnxruntime-linux-x64-1.12.1/lib/ -o test
40 |
41 | # Run
42 | ./test
43 | ```
--------------------------------------------------------------------------------
/examples/cpp/wav.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2016 Personal (Binbin Zhang)
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 | #ifndef FRONTEND_WAV_H_
16 | #define FRONTEND_WAV_H_
17 |
18 |
19 | #include <assert.h>
20 | #include <stdint.h>
21 | #include <stdio.h>
22 | #include <stdlib.h>
23 | #include <string.h>
24 |
25 | #include <string>
26 |
27 | #include <iostream>
28 |
29 | // #include "utils/log.h"
30 |
31 | namespace wav {
32 |
33 | struct WavHeader {
34 | char riff[4]; // "riff"
35 | unsigned int size;
36 | char wav[4]; // "WAVE"
37 | char fmt[4]; // "fmt "
38 | unsigned int fmt_size;
39 | uint16_t format;
40 | uint16_t channels;
41 | unsigned int sample_rate;
42 | unsigned int bytes_per_second;
43 | uint16_t block_size;
44 | uint16_t bit;
45 | char data[4]; // "data"
46 | unsigned int data_size;
47 | };
48 |
49 | class WavReader {
50 | public:
51 | WavReader() : data_(nullptr) {}
52 | explicit WavReader(const std::string& filename) { Open(filename); }
53 |
54 | bool Open(const std::string& filename) {
55 | FILE* fp = fopen(filename.c_str(), "rb"); // open the file for reading
56 | if (NULL == fp) {
57 | std::cout << "Error in read " << filename;
58 | return false;
59 | }
60 |
61 | WavHeader header;
62 | fread(&header, 1, sizeof(header), fp);
63 | if (header.fmt_size < 16) {
64 | printf("WaveData: expect PCM format data "
65 | "to have fmt chunk of at least size 16.\n");
66 | return false;
67 | } else if (header.fmt_size > 16) {
68 | int offset = 44 - 8 + header.fmt_size - 16;
69 | fseek(fp, offset, SEEK_SET);
70 | fread(header.data, 8, sizeof(char), fp);
71 | }
72 | // check "riff" "WAVE" "fmt " "data"
73 |
74 | // Skip any sub-chunks between "fmt" and "data". Usually there will
75 | // be a single "fact" sub chunk, but on Windows there can also be a
76 | // "list" sub chunk.
77 | while (0 != strncmp(header.data, "data", 4)) {
78 | // We will just ignore the data in these chunks.
79 | fseek(fp, header.data_size, SEEK_CUR);
80 | // read next sub chunk
81 | fread(header.data, 8, sizeof(char), fp);
82 | }
83 |
84 | if (header.data_size == 0) {
85 | int offset = ftell(fp);
86 | fseek(fp, 0, SEEK_END);
87 | header.data_size = ftell(fp) - offset;
88 | fseek(fp, offset, SEEK_SET);
89 | }
90 |
91 | num_channel_ = header.channels;
92 | sample_rate_ = header.sample_rate;
93 | bits_per_sample_ = header.bit;
94 | int num_data = header.data_size / (bits_per_sample_ / 8);
95 | data_ = new float[num_data]; // Create 1-dim array
96 | num_samples_ = num_data / num_channel_;
97 |
98 | std::cout << "num_channel_ :" << num_channel_ << std::endl;
99 | std::cout << "sample_rate_ :" << sample_rate_ << std::endl;
100 | std::cout << "bits_per_sample_:" << bits_per_sample_ << std::endl;
101 | std::cout << "num_samples :" << num_data << std::endl;
102 | std::cout << "num_data_size :" << header.data_size << std::endl;
103 |
104 | switch (bits_per_sample_) {
105 | case 8: {
106 | char sample;
107 | for (int i = 0; i < num_data; ++i) {
108 | fread(&sample, 1, sizeof(char), fp);
109 | data_[i] = static_cast<float>(sample) / 32768;
110 | }
111 | break;
112 | }
113 | case 16: {
114 | int16_t sample;
115 | for (int i = 0; i < num_data; ++i) {
116 | fread(&sample, 1, sizeof(int16_t), fp);
117 | data_[i] = static_cast<float>(sample) / 32768;
118 | }
119 | break;
120 | }
121 | case 32:
122 | {
123 | if (header.format == 1) //S32
124 | {
125 | int sample;
126 | for (int i = 0; i < num_data; ++i) {
127 | fread(&sample, 1, sizeof(int), fp);
128 | data_[i] = static_cast<float>(sample) / 32768;
129 | }
130 | }
131 | else if (header.format == 3) // IEEE-float
132 | {
133 | float sample;
134 | for (int i = 0; i < num_data; ++i) {
135 | fread(&sample, 1, sizeof(float), fp);
136 | data_[i] = static_cast<float>(sample);
137 | }
138 | }
139 | else {
140 | printf("unsupported quantization bits\n");
141 | }
142 | break;
143 | }
144 | default:
145 | printf("unsupported quantization bits\n");
146 | break;
147 | }
148 |
149 | fclose(fp);
150 | return true;
151 | }
152 |
153 | int num_channel() const { return num_channel_; }
154 | int sample_rate() const { return sample_rate_; }
155 | int bits_per_sample() const { return bits_per_sample_; }
156 | int num_samples() const { return num_samples_; }
157 |
158 | ~WavReader() {
159 | delete[] data_;
160 | }
161 |
162 | const float* data() const { return data_; }
163 |
164 | private:
165 | int num_channel_;
166 | int sample_rate_;
167 | int bits_per_sample_;
168 | int num_samples_; // sample points per channel
169 | float* data_;
170 | };
171 |
172 | class WavWriter {
173 | public:
174 | WavWriter(const float* data, int num_samples, int num_channel,
175 | int sample_rate, int bits_per_sample)
176 | : data_(data),
177 | num_samples_(num_samples),
178 | num_channel_(num_channel),
179 | sample_rate_(sample_rate),
180 | bits_per_sample_(bits_per_sample) {}
181 |
182 | void Write(const std::string& filename) {
183 | FILE* fp = fopen(filename.c_str(), "w");
184 | // init char 'riff' 'WAVE' 'fmt ' 'data'
185 | WavHeader header;
186 | char wav_header[44] = {0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57,
187 | 0x41, 0x56, 0x45, 0x66, 0x6d, 0x74, 0x20, 0x10, 0x00,
188 | 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
189 | 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
190 | 0x64, 0x61, 0x74, 0x61, 0x00, 0x00, 0x00, 0x00};
191 | memcpy(&header, wav_header, sizeof(header));
192 | header.channels = num_channel_;
193 | header.bit = bits_per_sample_;
194 | header.sample_rate = sample_rate_;
195 | header.data_size = num_samples_ * num_channel_ * (bits_per_sample_ / 8);
196 | header.size = sizeof(header) - 8 + header.data_size;
197 | header.bytes_per_second =
198 | sample_rate_ * num_channel_ * (bits_per_sample_ / 8);
199 | header.block_size = num_channel_ * (bits_per_sample_ / 8);
200 |
201 | fwrite(&header, 1, sizeof(header), fp);
202 |
203 | for (int i = 0; i < num_samples_; ++i) {
204 | for (int j = 0; j < num_channel_; ++j) {
205 | switch (bits_per_sample_) {
206 | case 8: {
207 | char sample = static_cast<char>(data_[i * num_channel_ + j]);
208 | fwrite(&sample, 1, sizeof(sample), fp);
209 | break;
210 | }
211 | case 16: {
212 | int16_t sample = static_cast<int16_t>(data_[i * num_channel_ + j]);
213 | fwrite(&sample, 1, sizeof(sample), fp);
214 | break;
215 | }
216 | case 32: {
217 | int sample = static_cast<int>(data_[i * num_channel_ + j]);
218 | fwrite(&sample, 1, sizeof(sample), fp);
219 | break;
220 | }
221 | }
222 | }
223 | }
224 | fclose(fp);
225 | }
226 |
227 | private:
228 | const float* data_;
229 | int num_samples_; // total float points in data_
230 | int num_channel_;
231 | int sample_rate_;
232 | int bits_per_sample_;
233 | };
234 |
235 | } // namespace wav
236 |
237 | #endif // FRONTEND_WAV_H_
238 |
--------------------------------------------------------------------------------
/examples/cpp_libtorch/README.md:
--------------------------------------------------------------------------------
1 | # Silero-VAD V5 in C++ (based on LibTorch)
2 |
3 | This is the source code for Silero-VAD V5 in C++, utilizing LibTorch. The primary implementation is CPU-based; you can compare its results with the Python version. Only results at 16 kHz have been tested.
4 |
5 | Additionally, batch and CUDA inference options are available if you want to explore further. Note that when using batch inference, the speech probabilities may slightly differ from the standard version, likely due to differences in caching. Unlike individual input processing, batch inference may not use the cache from previous chunks. Despite this, batch inference offers significantly faster processing. For optimal performance, consider adjusting the threshold when using batch inference.
6 |
7 | ## Requirements
8 |
9 | - GCC 11.4.0 (GCC >= 5.1)
10 | - LibTorch 1.13.0 (other versions are also acceptable)
11 |
12 | ## Download LibTorch
13 |
14 | ```bash
15 | # CPU Version
16 | wget https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-1.13.0%2Bcpu.zip
17 | unzip libtorch-shared-with-deps-1.13.0+cpu.zip
18 |
19 | # CUDA Version
20 | wget https://download.pytorch.org/libtorch/cu116/libtorch-shared-with-deps-1.13.0%2Bcu116.zip
21 | unzip libtorch-shared-with-deps-1.13.0+cu116.zip
22 | ```
23 |
24 | ## Compilation
25 |
26 | ```bash
27 | # CPU Version
28 | g++ main.cc silero_torch.cc -I ./libtorch/include/ -I ./libtorch/include/torch/csrc/api/include -L ./libtorch/lib/ -ltorch -ltorch_cpu -lc10 -Wl,-rpath,./libtorch/lib/ -o silero -std=c++14 -D_GLIBCXX_USE_CXX11_ABI=0
29 |
30 | # CUDA Version
31 | g++ main.cc silero_torch.cc -I ./libtorch/include/ -I ./libtorch/include/torch/csrc/api/include -L ./libtorch/lib/ -ltorch -ltorch_cuda -ltorch_cpu -lc10 -Wl,-rpath,./libtorch/lib/ -o silero -std=c++14 -D_GLIBCXX_USE_CXX11_ABI=0 -DUSE_GPU
32 | ```
33 |
34 |
35 | ## Optional Compilation Flags
36 | - `-DUSE_BATCH`: Enable batch inference
37 | - `-DUSE_GPU`: Use GPU for inference
38 |
39 | ## Run the Program
40 | To run the program, use the following command:
41 |
42 | `./silero aepyx.wav 16000 0.5`
43 |
44 | The sample file aepyx.wav is part of the VoxConverse dataset.
45 | File details: aepyx.wav is a 16 kHz, 16-bit audio file.
46 |
--------------------------------------------------------------------------------
/examples/cpp_libtorch/aepyx.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snakers4/silero-vad/0dd45f0bcd7271463c234f3bae5ad25181f9df8b/examples/cpp_libtorch/aepyx.wav
--------------------------------------------------------------------------------
/examples/cpp_libtorch/main.cc:
--------------------------------------------------------------------------------
1 | #include <iostream>
2 | #include "silero_torch.h"
3 | #include "wav.h"
4 |
5 | int main(int argc, char* argv[]) {
6 |
7 | if(argc != 4){
8 | std::cerr<<"Usage : "< "< input_wav(wav_reader.num_samples());
32 |
33 | for (int i = 0; i < wav_reader.num_samples(); i++)
34 | {
35 | input_wav[i] = static_cast(*(wav_reader.data() + i));
36 | }
37 |
38 | vad.SpeechProbs(input_wav);
39 |
40 | std::vector<silero::SpeechSegment> speeches = vad.GetSpeechTimestamps();
41 | for(const auto& speech : speeches){
42 | if(vad.print_as_samples){
43 | std::cout<<"{'start': "<(speech.start)<<", 'end': "<(speech.end)<<"}"<& input_wav){
22 | // Set the sample rate (must match the model's expected sample rate)
23 | // Process the waveform in chunks of 512 samples
24 | int num_samples = input_wav.size();
25 | int num_chunks = num_samples / window_size_samples;
26 | int remainder_samples = num_samples % window_size_samples;
27 |
28 | total_sample_size += num_samples;
29 |
30 | torch::Tensor output;
31 | std::vector<torch::Tensor> chunks;
32 |
33 | for (int i = 0; i < num_chunks; i++) {
34 |
35 | float* chunk_start = input_wav.data() + i *window_size_samples;
36 | torch::Tensor chunk = torch::from_blob(chunk_start, {1,window_size_samples}, torch::kFloat32);
37 | //std::cout<<"chunk size : "<0){//마지막 chunk && 나머지가 존재
42 | int remaining_samples = num_samples - num_chunks * window_size_samples;
43 | //std::cout<<"Remainder size : "< inputs;
69 | inputs.push_back(batched_chunks); // Batch of chunks
70 | inputs.push_back(sample_rate); // Assuming sample_rate is a valid input for the model
71 |
72 | // Run inference on the batch
73 | torch::NoGradGuard no_grad;
74 | torch::Tensor output = model.forward(inputs).toTensor();
75 | #ifdef USE_GPU
76 | output = output.to(at::kCPU); // Move the output back to CPU once
77 | #endif
78 | // Collect output probabilities
79 | for (int i = 0; i < chunks.size(); i++) {
80 | float output_f = output[i].item<float>();
81 | outputs_prob.push_back(output_f);
82 | //std::cout << "Chunk " << i << " prob: " << output_f<< "\n";
83 | }
84 | #else
85 |
86 | std::vector<torch::Tensor> outputs;
87 | torch::Tensor batched_chunks = torch::stack(chunks);
88 | #ifdef USE_GPU
89 | batched_chunks = batched_chunks.to(at::kCUDA);
90 | #endif
91 | for (int i = 0; i < chunks.size(); i++) {
92 | torch::NoGradGuard no_grad;
93 | std::vector<torch::jit::IValue> inputs;
94 | inputs.push_back(batched_chunks[i]);
95 | inputs.push_back(sample_rate);
96 |
97 | torch::Tensor output = model.forward(inputs).toTensor();
98 | outputs.push_back(output);
99 | }
100 | torch::Tensor all_outputs = torch::stack(outputs);
101 | #ifdef USE_GPU
102 | all_outputs = all_outputs.to(at::kCPU);
103 | #endif
104 | for (int i = 0; i < chunks.size(); i++) {
105 | float output_f = all_outputs[i].item<float>();
106 | outputs_prob.push_back(output_f);
107 | }
108 |
109 |
110 |
111 | #endif
112 |
113 | }
114 |
115 |
116 | }
117 |
118 |
119 | std::vector<SpeechSegment> VadIterator::GetSpeechTimestamps() {
120 | std::vector<SpeechSegment> speeches = DoVad();
121 |
122 | #ifdef USE_BATCH
123 | // When using batch inference, it is better to use the 'mergeSpeeches' function to arrange the timestamps.
124 | // It helps produce reasonable output despite the distorted probabilities.
125 | duration_merge_samples = sample_rate * max_duration_merge_ms / 1000;
126 | std::vector<SpeechSegment> speeches_merge = mergeSpeeches(speeches, duration_merge_samples);
127 | if(!print_as_samples){
128 | for (auto& speech : speeches_merge) { //samples to second
129 | speech.start /= sample_rate;
130 | speech.end /= sample_rate;
131 | }
132 | }
133 |
134 | return speeches_merge;
135 | #else
136 |
137 | if(!print_as_samples){
138 | for (auto& speech : speeches) { //samples to second
139 | speech.start /= sample_rate;
140 | speech.end /= sample_rate;
141 | }
142 | }
143 |
144 | return speeches;
145 |
146 | #endif
147 |
148 | }
149 | void VadIterator::SetVariables(){
150 | init_engine(window_size_ms);
151 | }
152 |
153 | void VadIterator::init_engine(int window_size_ms) {
154 | min_silence_samples = sample_rate * min_silence_duration_ms / 1000;
155 | speech_pad_samples = sample_rate * speech_pad_ms / 1000;
156 | window_size_samples = sample_rate / 1000 * window_size_ms;
157 | min_speech_samples = sample_rate * min_speech_duration_ms / 1000;
158 | }
159 |
160 | void VadIterator::init_torch_model(const std::string& model_path) {
161 | at::set_num_threads(1);
162 | model = torch::jit::load(model_path);
163 |
164 | #ifdef USE_GPU
165 | if (!torch::cuda::is_available()) {
166 | std::cout<<"CUDA is not available! Please check your GPU settings"< VadIterator::DoVad() {
192 | std::vector<SpeechSegment> speeches;
193 |
194 | for (size_t i = 0; i < outputs_prob.size(); ++i) {
195 | float speech_prob = outputs_prob[i];
196 | //std::cout << speech_prob << std::endl;
197 | //std::cout << "Chunk " << i << " Prob: " << speech_prob << "\n";
198 | //std::cout << speech_prob << " ";
199 | current_sample += window_size_samples;
200 |
201 | if (speech_prob >= threshold && temp_end != 0) {
202 | temp_end = 0;
203 | }
204 |
205 | if (speech_prob >= threshold && !triggered) {
206 | triggered = true;
207 | SpeechSegment segment;
208 | segment.start = std::max(static_cast<int>(0), current_sample - speech_pad_samples - window_size_samples);
209 | speeches.push_back(segment);
210 | continue;
211 | }
212 |
213 | if (speech_prob < threshold - 0.15f && triggered) {
214 | if (temp_end == 0) {
215 | temp_end = current_sample;
216 | }
217 |
218 | if (current_sample - temp_end < min_silence_samples) {
219 | continue;
220 | } else {
221 | SpeechSegment& segment = speeches.back();
222 | segment.end = temp_end + speech_pad_samples - window_size_samples;
223 | temp_end = 0;
224 | triggered = false;
225 | }
226 | }
227 | }
228 |
229 | if (triggered) { // If probabilities stay low and only the very last frame exceeds the threshold, `triggered = true` is set above and a segment start is created at the same moment; could that be an issue (start == end)? The post-processing below presumably handles it.
230 | std::cout<<"when last triggered is keep working until last Probs"<speech_pad_samples) - (speech.start + this->speech_pad_samples) < min_speech_samples);
242 | //min_speech_samples is 4000samples(0.25sec)
243 | // Key point: 'speech_pad_samples' is added to both the start and end samples before the segment length is measured.
244 | }
245 | ),
246 | speeches.end()
247 | );
248 |
249 |
250 | //std::cout< VadIterator::mergeSpeeches(const std::vector& speeches, int duration_merge_samples) {
258 | std::vector<SpeechSegment> mergedSpeeches;
259 |
260 | if (speeches.empty()) {
261 | return mergedSpeeches; // return an empty vector
262 | }
263 |
264 | // Initialize with the first segment
265 | SpeechSegment currentSegment = speeches[0];
266 |
267 | for (size_t i = 1; i < speeches.size(); ++i) { // skip the first start/end entry, hence i = 1
268 | // If the gap between two segments is smaller than the threshold (duration_merge_samples), merge them
269 | if (speeches[i].start - currentSegment.end < duration_merge_samples) {
270 | // Update the end of the current segment
271 | currentSegment.end = speeches[i].end;
272 | } else {
273 | // If the gap is at least the threshold (duration_merge_samples), store the current segment and start a new one
274 | mergedSpeeches.push_back(currentSegment);
275 | currentSegment = speeches[i];
276 | }
277 | }
278 |
279 | // Add the last segment
280 | mergedSpeeches.push_back(currentSegment);
281 |
282 | return mergedSpeeches;
283 | }
284 |
285 | }
286 |
--------------------------------------------------------------------------------
/examples/cpp_libtorch/silero_torch.h:
--------------------------------------------------------------------------------
1 | //Author : Nathan Lee
2 | //Created On : 2024-11-18
3 | //Description : silero 5.1 system for torch-script(c++).
4 | //Version : 1.0
5 |
6 | #ifndef SILERO_TORCH_H
7 | #define SILERO_TORCH_H
8 |
9 | #include
10 | #include
11 | #include
12 | #include
13 | #include
14 | #include
15 | #include
16 | #include
17 |
18 | #include
19 | #include
20 |
21 |
22 | namespace silero{
23 |
24 | struct SpeechSegment{
25 | int start;
26 | int end;
27 | };
28 |
29 | class VadIterator{
30 | public:
31 |
32 | VadIterator(const std::string &model_path, float threshold = 0.5, int sample_rate = 16000,
33 | int window_size_ms = 32, int speech_pad_ms = 30, int min_silence_duration_ms = 100,
34 | int min_speech_duration_ms = 250, int max_duration_merge_ms = 300, bool print_as_samples = false);
35 | ~VadIterator();
36 |
37 |
38 | void SpeechProbs(std::vector<float>& input_wav);
39 | std::vector<SpeechSegment> GetSpeechTimestamps();
40 | void SetVariables();
41 |
42 | float threshold;
43 | int sample_rate;
44 | int window_size_ms;
45 | int min_speech_duration_ms;
46 | int max_duration_merge_ms;
47 | bool print_as_samples;
48 |
49 | private:
50 | torch::jit::script::Module model;
51 | std::vector<float> outputs_prob;
52 | int min_silence_samples;
53 | int min_speech_samples;
54 | int speech_pad_samples;
55 | int window_size_samples;
56 | int duration_merge_samples;
57 | int current_sample = 0;
58 |
59 | int total_sample_size=0;
60 |
61 | int min_silence_duration_ms;
62 | int speech_pad_ms;
63 | bool triggered = false;
64 | int temp_end = 0;
65 |
66 | void init_engine(int window_size_ms);
67 | void init_torch_model(const std::string& model_path);
68 | void reset_states();
69 | std::vector<SpeechSegment> DoVad();
70 | std::vector<SpeechSegment> mergeSpeeches(const std::vector<SpeechSegment>& speeches, int duration_merge_samples);
71 |
72 | };
73 |
74 | }
75 | #endif // SILERO_TORCH_H
76 |
--------------------------------------------------------------------------------
/examples/cpp_libtorch/wav.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2016 Personal (Binbin Zhang)
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | // http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 |
15 |
16 | #ifndef FRONTEND_WAV_H_
17 | #define FRONTEND_WAV_H_
18 |
19 | #include
20 | #include
21 | #include
22 | #include
23 | #include
24 |
25 | #include
26 |
27 | // #include "utils/log.h"
28 |
29 | namespace wav {
30 |
31 | struct WavHeader {
32 | char riff[4]; // "riff"
33 | unsigned int size;
34 | char wav[4]; // "WAVE"
35 | char fmt[4]; // "fmt "
36 | unsigned int fmt_size;
37 | uint16_t format;
38 | uint16_t channels;
39 | unsigned int sample_rate;
40 | unsigned int bytes_per_second;
41 | uint16_t block_size;
42 | uint16_t bit;
43 | char data[4]; // "data"
44 | unsigned int data_size;
45 | };
46 |
47 | class WavReader {
48 | public:
49 | WavReader() : data_(nullptr) {}
50 | explicit WavReader(const std::string& filename) { Open(filename); }
51 |
52 | bool Open(const std::string& filename) {
53 | FILE* fp = fopen(filename.c_str(), "rb"); // open the file for reading
54 | if (NULL == fp) {
55 | std::cout << "Error in read " << filename;
56 | return false;
57 | }
58 |
59 | WavHeader header;
60 | fread(&header, 1, sizeof(header), fp);
61 | if (header.fmt_size < 16) {
62 | printf("WaveData: expect PCM format data "
63 | "to have fmt chunk of at least size 16.\n");
64 | return false;
65 | } else if (header.fmt_size > 16) {
66 | int offset = 44 - 8 + header.fmt_size - 16;
67 | fseek(fp, offset, SEEK_SET);
68 | fread(header.data, 8, sizeof(char), fp);
69 | }
70 | // check "riff" "WAVE" "fmt " "data"
71 |
72 | // Skip any sub-chunks between "fmt" and "data". Usually there will
73 | // be a single "fact" sub chunk, but on Windows there can also be a
74 | // "list" sub chunk.
75 | while (0 != strncmp(header.data, "data", 4)) {
76 | // We will just ignore the data in these chunks.
77 | fseek(fp, header.data_size, SEEK_CUR);
78 | // read next sub chunk
79 | fread(header.data, 8, sizeof(char), fp);
80 | }
81 |
82 | if (header.data_size == 0) {
83 | int offset = ftell(fp);
84 | fseek(fp, 0, SEEK_END);
85 | header.data_size = ftell(fp) - offset;
86 | fseek(fp, offset, SEEK_SET);
87 | }
88 |
89 | num_channel_ = header.channels;
90 | sample_rate_ = header.sample_rate;
91 | bits_per_sample_ = header.bit;
92 | int num_data = header.data_size / (bits_per_sample_ / 8);
93 | data_ = new float[num_data]; // Create 1-dim array
94 | num_samples_ = num_data / num_channel_;
95 |
96 | std::cout << "num_channel_ :" << num_channel_ << std::endl;
97 | std::cout << "sample_rate_ :" << sample_rate_ << std::endl;
98 | std::cout << "bits_per_sample_:" << bits_per_sample_ << std::endl;
99 | std::cout << "num_samples :" << num_data << std::endl;
100 | std::cout << "num_data_size :" << header.data_size << std::endl;
101 |
102 | switch (bits_per_sample_) {
103 | case 8: {
104 | char sample;
105 | for (int i = 0; i < num_data; ++i) {
106 | fread(&sample, 1, sizeof(char), fp);
107 | data_[i] = static_cast<float>(sample) / 32768;
108 | }
109 | break;
110 | }
111 | case 16: {
112 | int16_t sample;
113 | for (int i = 0; i < num_data; ++i) {
114 | fread(&sample, 1, sizeof(int16_t), fp);
115 | data_[i] = static_cast<float>(sample) / 32768;
116 | }
117 | break;
118 | }
119 | case 32:
120 | {
121 | if (header.format == 1) //S32
122 | {
123 | int sample;
124 | for (int i = 0; i < num_data; ++i) {
125 | fread(&sample, 1, sizeof(int), fp);
126 | data_[i] = static_cast<float>(sample) / 32768;
127 | }
128 | }
129 | else if (header.format == 3) // IEEE-float
130 | {
131 | float sample;
132 | for (int i = 0; i < num_data; ++i) {
133 | fread(&sample, 1, sizeof(float), fp);
134 | data_[i] = static_cast<float>(sample);
135 | }
136 | }
137 | else {
138 | printf("unsupported quantization bits\n");
139 | }
140 | break;
141 | }
142 | default:
143 | printf("unsupported quantization bits\n");
144 | break;
145 | }
146 |
147 | fclose(fp);
148 | return true;
149 | }
150 |
151 | int num_channel() const { return num_channel_; }
152 | int sample_rate() const { return sample_rate_; }
153 | int bits_per_sample() const { return bits_per_sample_; }
154 | int num_samples() const { return num_samples_; }
155 |
156 | ~WavReader() {
157 | delete[] data_;
158 | }
159 |
160 | const float* data() const { return data_; }
161 |
162 | private:
163 | int num_channel_;
164 | int sample_rate_;
165 | int bits_per_sample_;
166 | int num_samples_; // sample points per channel
167 | float* data_;
168 | };
169 |
170 | class WavWriter {
171 | public:
172 | WavWriter(const float* data, int num_samples, int num_channel,
173 | int sample_rate, int bits_per_sample)
174 | : data_(data),
175 | num_samples_(num_samples),
176 | num_channel_(num_channel),
177 | sample_rate_(sample_rate),
178 | bits_per_sample_(bits_per_sample) {}
179 |
180 | void Write(const std::string& filename) {
181 | FILE* fp = fopen(filename.c_str(), "w");
182 | // init char 'riff' 'WAVE' 'fmt ' 'data'
183 | WavHeader header;
184 | char wav_header[44] = {0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57,
185 | 0x41, 0x56, 0x45, 0x66, 0x6d, 0x74, 0x20, 0x10, 0x00,
186 | 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
187 | 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
188 | 0x64, 0x61, 0x74, 0x61, 0x00, 0x00, 0x00, 0x00};
189 | memcpy(&header, wav_header, sizeof(header));
190 | header.channels = num_channel_;
191 | header.bit = bits_per_sample_;
192 | header.sample_rate = sample_rate_;
193 | header.data_size = num_samples_ * num_channel_ * (bits_per_sample_ / 8);
194 | header.size = sizeof(header) - 8 + header.data_size;
195 | header.bytes_per_second =
196 | sample_rate_ * num_channel_ * (bits_per_sample_ / 8);
197 | header.block_size = num_channel_ * (bits_per_sample_ / 8);
198 |
199 | fwrite(&header, 1, sizeof(header), fp);
200 |
201 | for (int i = 0; i < num_samples_; ++i) {
202 | for (int j = 0; j < num_channel_; ++j) {
203 | switch (bits_per_sample_) {
204 | case 8: {
205 | char sample = static_cast<char>(data_[i * num_channel_ + j]);
206 | fwrite(&sample, 1, sizeof(sample), fp);
207 | break;
208 | }
209 | case 16: {
210 | int16_t sample = static_cast<int16_t>(data_[i * num_channel_ + j]);
211 | fwrite(&sample, 1, sizeof(sample), fp);
212 | break;
213 | }
214 | case 32: {
215 | int sample = static_cast<int>(data_[i * num_channel_ + j]);
216 | fwrite(&sample, 1, sizeof(sample), fp);
217 | break;
218 | }
219 | }
220 | }
221 | }
222 | fclose(fp);
223 | }
224 |
225 | private:
226 | const float* data_;
227 | int num_samples_; // total float points in data_
228 | int num_channel_;
229 | int sample_rate_;
230 | int bits_per_sample_;
231 | };
232 |
233 | } // namespace wenet
234 |
235 | #endif // FRONTEND_WAV_H_
236 |
--------------------------------------------------------------------------------
/examples/csharp/Program.cs:
--------------------------------------------------------------------------------
1 | using System.Text;
2 |
3 | namespace VadDotNet;
4 |
5 |
6 | class Program
7 | {
8 | private const string MODEL_PATH = "./resources/silero_vad.onnx";
9 | private const string EXAMPLE_WAV_FILE = "./resources/example.wav";
10 | private const int SAMPLE_RATE = 16000;
11 | private const float THRESHOLD = 0.5f;
12 | private const int MIN_SPEECH_DURATION_MS = 250;
13 | private const float MAX_SPEECH_DURATION_SECONDS = float.PositiveInfinity;
14 | private const int MIN_SILENCE_DURATION_MS = 100;
15 | private const int SPEECH_PAD_MS = 30;
16 |
17 | public static void Main(string[] args)
18 | {
19 |
20 | var vadDetector = new SileroVadDetector(MODEL_PATH, THRESHOLD, SAMPLE_RATE,
21 | MIN_SPEECH_DURATION_MS, MAX_SPEECH_DURATION_SECONDS, MIN_SILENCE_DURATION_MS, SPEECH_PAD_MS);
22 | List<SileroSpeechSegment> speechTimeList = vadDetector.GetSpeechSegmentList(new FileInfo(EXAMPLE_WAV_FILE));
23 | //Console.WriteLine(speechTimeList.ToJson());
24 | StringBuilder sb = new StringBuilder();
25 | foreach (var speechSegment in speechTimeList)
26 | {
27 | sb.Append($"start second: {speechSegment.StartSecond}, end second: {speechSegment.EndSecond}\n");
28 |
29 | }
30 | Console.WriteLine(sb.ToString());
31 |
32 | }
33 |
34 |
35 | }
36 |
--------------------------------------------------------------------------------
/examples/csharp/SileroSpeechSegment.cs:
--------------------------------------------------------------------------------
1 | namespace VadDotNet;
2 |
3 | public class SileroSpeechSegment
4 | {
5 | public int? StartOffset { get; set; }
6 | public int? EndOffset { get; set; }
7 | public float? StartSecond { get; set; }
8 | public float? EndSecond { get; set; }
9 |
10 | public SileroSpeechSegment()
11 | {
12 | }
13 |
14 | public SileroSpeechSegment(int startOffset, int? endOffset, float? startSecond, float? endSecond)
15 | {
16 | StartOffset = startOffset;
17 | EndOffset = endOffset;
18 | StartSecond = startSecond;
19 | EndSecond = endSecond;
20 | }
21 | }
--------------------------------------------------------------------------------
/examples/csharp/SileroVadDetector.cs:
--------------------------------------------------------------------------------
1 | using NAudio.Wave;
2 | using VADdotnet;
3 |
4 | namespace VadDotNet;
5 |
6 | public class SileroVadDetector
7 | {
8 | private readonly SileroVadOnnxModel _model;
9 | private readonly float _threshold;
10 | private readonly float _negThreshold;
11 | private readonly int _samplingRate;
12 | private readonly int _windowSizeSample;
13 | private readonly float _minSpeechSamples;
14 | private readonly float _speechPadSamples;
15 | private readonly float _maxSpeechSamples;
16 | private readonly float _minSilenceSamples;
17 | private readonly float _minSilenceSamplesAtMaxSpeech;
18 | private int _audioLengthSamples;
19 | private const float THRESHOLD_GAP = 0.15f;
20 | // ReSharper disable once InconsistentNaming
21 | private const int SAMPLING_RATE_8K = 8000;
22 | // ReSharper disable once InconsistentNaming
23 | private const int SAMPLING_RATE_16K = 16000;
24 |
25 | public SileroVadDetector(string onnxModelPath, float threshold, int samplingRate,
26 | int minSpeechDurationMs, float maxSpeechDurationSeconds,
27 | int minSilenceDurationMs, int speechPadMs)
28 | {
29 | if (samplingRate != SAMPLING_RATE_8K && samplingRate != SAMPLING_RATE_16K)
30 | {
31 | throw new ArgumentException("Sampling rate not supported, only 8000 and 16000 are available");
32 | }
33 |
34 | this._model = new SileroVadOnnxModel(onnxModelPath);
35 | this._samplingRate = samplingRate;
36 | this._threshold = threshold;
37 | this._negThreshold = threshold - THRESHOLD_GAP;
38 | this._windowSizeSample = samplingRate == SAMPLING_RATE_16K ? 512 : 256;
39 | this._minSpeechSamples = samplingRate * minSpeechDurationMs / 1000f;
40 | this._speechPadSamples = samplingRate * speechPadMs / 1000f;
41 | this._maxSpeechSamples = samplingRate * maxSpeechDurationSeconds - _windowSizeSample - 2 * _speechPadSamples;
42 | this._minSilenceSamples = samplingRate * minSilenceDurationMs / 1000f;
43 | this._minSilenceSamplesAtMaxSpeech = samplingRate * 98 / 1000f;
44 | this.Reset();
45 | }
46 |
47 | public void Reset()
48 | {
49 | _model.ResetStates();
50 | }
51 |
52 | public List<SileroSpeechSegment> GetSpeechSegmentList(FileInfo wavFile)
53 | {
54 | Reset();
55 |
56 | using (var audioFile = new AudioFileReader(wavFile.FullName))
57 | {
58 | List<float> speechProbList = new List<float>();
59 | this._audioLengthSamples = (int)(audioFile.Length / 2);
60 | float[] buffer = new float[this._windowSizeSample];
61 |
62 | while (audioFile.Read(buffer, 0, buffer.Length) > 0)
63 | {
64 | float speechProb = _model.Call(new[] { buffer }, _samplingRate)[0];
65 | speechProbList.Add(speechProb);
66 | }
67 |
68 | return CalculateProb(speechProbList);
69 | }
70 | }
71 |
72 | private List<SileroSpeechSegment> CalculateProb(List<float> speechProbList)
73 | {
74 | List<SileroSpeechSegment> result = new List<SileroSpeechSegment>();
75 | bool triggered = false;
76 | int tempEnd = 0, prevEnd = 0, nextStart = 0;
77 | SileroSpeechSegment segment = new SileroSpeechSegment();
78 |
79 | for (int i = 0; i < speechProbList.Count; i++)
80 | {
81 | float speechProb = speechProbList[i];
82 | if (speechProb >= _threshold && (tempEnd != 0))
83 | {
84 | tempEnd = 0;
85 | if (nextStart < prevEnd)
86 | {
87 | nextStart = _windowSizeSample * i;
88 | }
89 | }
90 |
91 | if (speechProb >= _threshold && !triggered)
92 | {
93 | triggered = true;
94 | segment.StartOffset = _windowSizeSample * i;
95 | continue;
96 | }
97 |
98 | if (triggered && (_windowSizeSample * i) - segment.StartOffset > _maxSpeechSamples)
99 | {
100 | if (prevEnd != 0)
101 | {
102 | segment.EndOffset = prevEnd;
103 | result.Add(segment);
104 | segment = new SileroSpeechSegment();
105 | if (nextStart < prevEnd)
106 | {
107 | triggered = false;
108 | }
109 | else
110 | {
111 | segment.StartOffset = nextStart;
112 | }
113 |
114 | prevEnd = 0;
115 | nextStart = 0;
116 | tempEnd = 0;
117 | }
118 | else
119 | {
120 | segment.EndOffset = _windowSizeSample * i;
121 | result.Add(segment);
122 | segment = new SileroSpeechSegment();
123 | prevEnd = 0;
124 | nextStart = 0;
125 | tempEnd = 0;
126 | triggered = false;
127 | continue;
128 | }
129 | }
130 |
131 | if (speechProb < _negThreshold && triggered)
132 | {
133 | if (tempEnd == 0)
134 | {
135 | tempEnd = _windowSizeSample * i;
136 | }
137 |
138 | if (((_windowSizeSample * i) - tempEnd) > _minSilenceSamplesAtMaxSpeech)
139 | {
140 | prevEnd = tempEnd;
141 | }
142 |
143 | if ((_windowSizeSample * i) - tempEnd < _minSilenceSamples)
144 | {
145 | continue;
146 | }
147 | else
148 | {
149 | segment.EndOffset = tempEnd;
150 | if ((segment.EndOffset - segment.StartOffset) > _minSpeechSamples)
151 | {
152 | result.Add(segment);
153 | }
154 |
155 | segment = new SileroSpeechSegment();
156 | prevEnd = 0;
157 | nextStart = 0;
158 | tempEnd = 0;
159 | triggered = false;
160 | continue;
161 | }
162 | }
163 | }
164 |
165 | if (segment.StartOffset != null && (_audioLengthSamples - segment.StartOffset) > _minSpeechSamples)
166 | {
167 | segment.EndOffset = _audioLengthSamples;
168 | result.Add(segment);
169 | }
170 |
171 | for (int i = 0; i < result.Count; i++)
172 | {
173 | SileroSpeechSegment item = result[i];
174 | if (i == 0)
175 | {
176 | item.StartOffset = (int)Math.Max(0, item.StartOffset.Value - _speechPadSamples);
177 | }
178 |
179 | if (i != result.Count - 1)
180 | {
181 | SileroSpeechSegment nextItem = result[i + 1];
182 | int silenceDuration = nextItem.StartOffset.Value - item.EndOffset.Value;
183 | if (silenceDuration < 2 * _speechPadSamples)
184 | {
185 | item.EndOffset = item.EndOffset + (silenceDuration / 2);
186 | nextItem.StartOffset = Math.Max(0, nextItem.StartOffset.Value - (silenceDuration / 2));
187 | }
188 | else
189 | {
190 | item.EndOffset = (int)Math.Min(_audioLengthSamples, item.EndOffset.Value + _speechPadSamples);
191 | nextItem.StartOffset = (int)Math.Max(0, nextItem.StartOffset.Value - _speechPadSamples);
192 | }
193 | }
194 | else
195 | {
196 | item.EndOffset = (int)Math.Min(_audioLengthSamples, item.EndOffset.Value + _speechPadSamples);
197 | }
198 | }
199 |
200 | return MergeListAndCalculateSecond(result, _samplingRate);
201 | }
202 |
203 | private List<SileroSpeechSegment> MergeListAndCalculateSecond(List<SileroSpeechSegment> original, int samplingRate)
204 | {
205 | List<SileroSpeechSegment> result = new List<SileroSpeechSegment>();
206 | if (original == null || original.Count == 0)
207 | {
208 | return result;
209 | }
210 |
211 | int left = original[0].StartOffset.Value;
212 | int right = original[0].EndOffset.Value;
213 | if (original.Count > 1)
214 | {
215 | original.Sort((a, b) => a.StartOffset.Value.CompareTo(b.StartOffset.Value));
216 | for (int i = 1; i < original.Count; i++)
217 | {
218 | SileroSpeechSegment segment = original[i];
219 |
220 | if (segment.StartOffset > right)
221 | {
222 | result.Add(new SileroSpeechSegment(left, right,
223 | CalculateSecondByOffset(left, samplingRate), CalculateSecondByOffset(right, samplingRate)));
224 | left = segment.StartOffset.Value;
225 | right = segment.EndOffset.Value;
226 | }
227 | else
228 | {
229 | right = Math.Max(right, segment.EndOffset.Value);
230 | }
231 | }
232 |
233 | result.Add(new SileroSpeechSegment(left, right,
234 | CalculateSecondByOffset(left, samplingRate), CalculateSecondByOffset(right, samplingRate)));
235 | }
236 | else
237 | {
238 | result.Add(new SileroSpeechSegment(left, right,
239 | CalculateSecondByOffset(left, samplingRate), CalculateSecondByOffset(right, samplingRate)));
240 | }
241 |
242 | return result;
243 | }
244 |
245 | private float CalculateSecondByOffset(int offset, int samplingRate)
246 | {
247 | float secondValue = offset * 1.0f / samplingRate;
248 | return (float)Math.Floor(secondValue * 1000.0f) / 1000.0f;
249 | }
250 | }
--------------------------------------------------------------------------------
/examples/csharp/SileroVadOnnxModel.cs:
--------------------------------------------------------------------------------
1 | using Microsoft.ML.OnnxRuntime;
2 | using Microsoft.ML.OnnxRuntime.Tensors;
3 | using System;
4 | using System.Collections.Generic;
5 | using System.Linq;
6 |
7 | namespace VADdotnet;
8 |
9 |
10 | public class SileroVadOnnxModel : IDisposable
11 | {
12 | private readonly InferenceSession session;
13 | private float[][][] state;
14 | private float[][] context;
15 | private int lastSr = 0;
16 | private int lastBatchSize = 0;
17 | private static readonly List<int> SAMPLE_RATES = new List<int> { 8000, 16000 };
18 |
19 | public SileroVadOnnxModel(string modelPath)
20 | {
21 | var sessionOptions = new SessionOptions();
22 | sessionOptions.InterOpNumThreads = 1;
23 | sessionOptions.IntraOpNumThreads = 1;
24 | sessionOptions.EnableCpuMemArena = true;
25 |
26 | session = new InferenceSession(modelPath, sessionOptions);
27 | ResetStates();
28 | }
29 |
30 | public void ResetStates()
31 | {
32 | state = new float[2][][];
33 | state[0] = new float[1][];
34 | state[1] = new float[1][];
35 | state[0][0] = new float[128];
36 | state[1][0] = new float[128];
37 | context = Array.Empty<float[]>();
38 | lastSr = 0;
39 | lastBatchSize = 0;
40 | }
41 |
42 | public void Dispose()
43 | {
44 | session?.Dispose();
45 | }
46 |
47 | public class ValidationResult
48 | {
49 | public float[][] X { get; }
50 | public int Sr { get; }
51 |
52 | public ValidationResult(float[][] x, int sr)
53 | {
54 | X = x;
55 | Sr = sr;
56 | }
57 | }
58 |
59 | private ValidationResult ValidateInput(float[][] x, int sr)
60 | {
61 | if (x.Length == 1)
62 | {
63 | x = new float[][] { x[0] };
64 | }
65 | if (x.Length > 2)
66 | {
67 | throw new ArgumentException($"Incorrect audio data dimension: {x[0].Length}");
68 | }
69 |
70 | if (sr != 16000 && (sr % 16000 == 0))
71 | {
72 | int step = sr / 16000;
73 | float[][] reducedX = new float[x.Length][];
74 |
75 | for (int i = 0; i < x.Length; i++)
76 | {
77 | float[] current = x[i];
78 | float[] newArr = new float[(current.Length + step - 1) / step];
79 |
80 | for (int j = 0, index = 0; j < current.Length; j += step, index++)
81 | {
82 | newArr[index] = current[j];
83 | }
84 |
85 | reducedX[i] = newArr;
86 | }
87 |
88 | x = reducedX;
89 | sr = 16000;
90 | }
91 |
92 | if (!SAMPLE_RATES.Contains(sr))
93 | {
94 | throw new ArgumentException($"Only supports sample rates {string.Join(", ", SAMPLE_RATES)} (or multiples of 16000)");
95 | }
96 |
97 | if (((float)sr) / x[0].Length > 31.25)
98 | {
99 | throw new ArgumentException("Input audio is too short");
100 | }
101 |
102 | return new ValidationResult(x, sr);
103 | }
104 |
105 | private static float[][] Concatenate(float[][] a, float[][] b)
106 | {
107 | if (a.Length != b.Length)
108 | {
109 | throw new ArgumentException("The number of rows in both arrays must be the same.");
110 | }
111 |
112 | int rows = a.Length;
113 | int colsA = a[0].Length;
114 | int colsB = b[0].Length;
115 | float[][] result = new float[rows][];
116 |
117 | for (int i = 0; i < rows; i++)
118 | {
119 | result[i] = new float[colsA + colsB];
120 | Array.Copy(a[i], 0, result[i], 0, colsA);
121 | Array.Copy(b[i], 0, result[i], colsA, colsB);
122 | }
123 |
124 | return result;
125 | }
126 |
127 | private static float[][] GetLastColumns(float[][] array, int contextSize)
128 | {
129 | int rows = array.Length;
130 | int cols = array[0].Length;
131 |
132 | if (contextSize > cols)
133 | {
134 | throw new ArgumentException("contextSize cannot be greater than the number of columns in the array.");
135 | }
136 |
137 | float[][] result = new float[rows][];
138 |
139 | for (int i = 0; i < rows; i++)
140 | {
141 | result[i] = new float[contextSize];
142 | Array.Copy(array[i], cols - contextSize, result[i], 0, contextSize);
143 | }
144 |
145 | return result;
146 | }
147 |
148 | public float[] Call(float[][] x, int sr)
149 | {
150 | var result = ValidateInput(x, sr);
151 | x = result.X;
152 | sr = result.Sr;
153 | int numberSamples = sr == 16000 ? 512 : 256;
154 |
155 | if (x[0].Length != numberSamples)
156 | {
157 | throw new ArgumentException($"Provided number of samples is {x[0].Length} (Supported values: 256 for 8000 sample rate, 512 for 16000)");
158 | }
159 |
160 | int batchSize = x.Length;
161 | int contextSize = sr == 16000 ? 64 : 32;
162 |
163 | if (lastBatchSize == 0)
164 | {
165 | ResetStates();
166 | }
167 | if (lastSr != 0 && lastSr != sr)
168 | {
169 | ResetStates();
170 | }
171 | if (lastBatchSize != 0 && lastBatchSize != batchSize)
172 | {
173 | ResetStates();
174 | }
175 |
176 | if (context.Length == 0)
177 | {
178 | context = new float[batchSize][];
179 | for (int i = 0; i < batchSize; i++)
180 | {
181 | context[i] = new float[contextSize];
182 | }
183 | }
184 |
185 | x = Concatenate(context, x);
186 |
187 | var inputs = new List<NamedOnnxValue>
188 | {
189 | NamedOnnxValue.CreateFromTensor("input", new DenseTensor<float>(x.SelectMany(a => a).ToArray(), new[] { x.Length, x[0].Length })),
190 | NamedOnnxValue.CreateFromTensor("sr", new DenseTensor<long>(new[] { (long)sr }, new[] { 1 })),
191 | NamedOnnxValue.CreateFromTensor("state", new DenseTensor<float>(state.SelectMany(a => a.SelectMany(b => b)).ToArray(), new[] { state.Length, state[0].Length, state[0][0].Length }))
192 | };
193 |
194 | using (var outputs = session.Run(inputs))
195 | {
196 | var output = outputs.First(o => o.Name == "output").AsTensor<float>();
197 | var newState = outputs.First(o => o.Name == "stateN").AsTensor<float>();
198 |
199 | context = GetLastColumns(x, contextSize);
200 | lastSr = sr;
201 | lastBatchSize = batchSize;
202 |
203 | state = new float[newState.Dimensions[0]][][];
204 | for (int i = 0; i < newState.Dimensions[0]; i++)
205 | {
206 | state[i] = new float[newState.Dimensions[1]][];
207 | for (int j = 0; j < newState.Dimensions[1]; j++)
208 | {
209 | state[i][j] = new float[newState.Dimensions[2]];
210 | for (int k = 0; k < newState.Dimensions[2]; k++)
211 | {
212 | state[i][j][k] = newState[i, j, k];
213 | }
214 | }
215 | }
216 |
217 | return output.ToArray();
218 | }
219 | }
220 | }
221 |
--------------------------------------------------------------------------------
/examples/csharp/VadDotNet.csproj:
--------------------------------------------------------------------------------
1 | <Project Sdk="Microsoft.NET.Sdk">
2 |
3 |   <PropertyGroup>
4 |     <OutputType>Exe</OutputType>
5 |     <TargetFramework>net8.0</TargetFramework>
6 |     <ImplicitUsings>enable</ImplicitUsings>
7 |     <Nullable>enable</Nullable>
8 |   </PropertyGroup>
9 |
10 |   <ItemGroup>
11 |     <PackageReference Include="Microsoft.ML.OnnxRuntime" />
12 |     <PackageReference Include="NAudio" />
13 |   </ItemGroup>
14 |
15 |   <ItemGroup>
16 |     <Content Include="resources\**">
17 |       <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
18 |     </Content>
19 |   </ItemGroup>
20 |
21 | </Project>
--------------------------------------------------------------------------------
/examples/csharp/resources/put_model_here.txt:
--------------------------------------------------------------------------------
1 | Place the ONNX model file and an example.wav file in this folder
2 |
--------------------------------------------------------------------------------
/examples/go/README.md:
--------------------------------------------------------------------------------
1 | ## Golang Example
2 |
3 | This is a sample program showing how to run speech detection using `silero-vad` from Golang (CGO + ONNX Runtime).
4 |
5 | ### Requirements
6 |
7 | - Golang >= v1.21
8 | - ONNX Runtime
9 |
10 | ### Usage
11 |
12 | ```sh
13 | go run ./cmd/main.go test.wav
14 | ```
15 |
16 | > **_Note_**
17 | >
18 | > Make sure you have the ONNX Runtime library and C headers installed in your path.
19 |
20 |
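21 | If ONNX Runtime is installed in a non-standard location, one way to point the Go toolchain at it is through the usual cgo environment variables. The paths below are illustrative only; adjust them to wherever your ONNX Runtime headers and library live:
22 |
23 | ```sh
24 | # Illustrative paths; adjust to your ONNX Runtime install.
25 | export CGO_CFLAGS="-I/usr/local/include"
26 | export CGO_LDFLAGS="-L/usr/local/lib -lonnxruntime"
27 | export LD_LIBRARY_PATH="/usr/local/lib:$LD_LIBRARY_PATH"
28 | go run ./cmd/main.go test.wav
29 | ```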
--------------------------------------------------------------------------------
/examples/go/cmd/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "log"
5 | "os"
6 |
7 | "github.com/streamer45/silero-vad-go/speech"
8 |
9 | "github.com/go-audio/wav"
10 | )
11 |
12 | func main() {
13 | sd, err := speech.NewDetector(speech.DetectorConfig{
14 | ModelPath: "../../src/silero_vad/data/silero_vad.onnx",
15 | SampleRate: 16000,
16 | Threshold: 0.5,
17 | MinSilenceDurationMs: 100,
18 | SpeechPadMs: 30,
19 | })
20 | if err != nil {
21 | log.Fatalf("failed to create speech detector: %s", err)
22 | }
23 |
24 | if len(os.Args) != 2 {
25 | log.Fatalf("invalid arguments provided: expecting one file path")
26 | }
27 |
28 | f, err := os.Open(os.Args[1])
29 | if err != nil {
30 | log.Fatalf("failed to open sample audio file: %s", err)
31 | }
32 | defer f.Close()
33 |
34 | dec := wav.NewDecoder(f)
35 |
36 | if ok := dec.IsValidFile(); !ok {
37 | log.Fatalf("invalid WAV file")
38 | }
39 |
40 | buf, err := dec.FullPCMBuffer()
41 | if err != nil {
42 | log.Fatalf("failed to get PCM buffer")
43 | }
44 |
45 | pcmBuf := buf.AsFloat32Buffer()
46 |
47 | segments, err := sd.Detect(pcmBuf.Data)
48 | if err != nil {
49 | log.Fatalf("Detect failed: %s", err)
50 | }
51 |
52 | for _, s := range segments {
53 | log.Printf("speech starts at %0.2fs", s.SpeechStartAt)
54 | if s.SpeechEndAt > 0 {
55 | log.Printf("speech ends at %0.2fs", s.SpeechEndAt)
56 | }
57 | }
58 |
59 | err = sd.Destroy()
60 | if err != nil {
61 | log.Fatalf("failed to destroy detector: %s", err)
62 | }
63 | }
64 |
--------------------------------------------------------------------------------
/examples/go/go.mod:
--------------------------------------------------------------------------------
1 | module silero
2 |
3 | go 1.21.4
4 |
5 | require (
6 | github.com/go-audio/wav v1.1.0
7 | github.com/streamer45/silero-vad-go v0.2.1
8 | )
9 |
10 | require (
11 | github.com/go-audio/audio v1.0.0 // indirect
12 | github.com/go-audio/riff v1.0.0 // indirect
13 | )
14 |
--------------------------------------------------------------------------------
/examples/go/go.sum:
--------------------------------------------------------------------------------
1 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
2 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
3 | github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4=
4 | github.com/go-audio/audio v1.0.0/go.mod h1:6uAu0+H2lHkwdGsAY+j2wHPNPpPoeg5AaEFh9FlA+Zs=
5 | github.com/go-audio/riff v1.0.0 h1:d8iCGbDvox9BfLagY94fBynxSPHO80LmZCaOsmKxokA=
6 | github.com/go-audio/riff v1.0.0/go.mod h1:l3cQwc85y79NQFCRB7TiPoNiaijp6q8Z0Uv38rVG498=
7 | github.com/go-audio/wav v1.1.0 h1:jQgLtbqBzY7G+BM8fXF7AHUk1uHUviWS4X39d5rsL2g=
8 | github.com/go-audio/wav v1.1.0/go.mod h1:mpe9qfwbScEbkd8uybLuIpTgHyrISw/OTuvjUW2iGtE=
9 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
10 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
11 | github.com/streamer45/silero-vad-go v0.2.0 h1:bbRTa6cQuc7VI88y0qicx375UyWoxE6wlVOF+mUg0+g=
12 | github.com/streamer45/silero-vad-go v0.2.0/go.mod h1:B+2FXs/5fZ6pzl6unUZYhZqkYdOB+3saBVzjOzdZnUs=
13 | github.com/streamer45/silero-vad-go v0.2.1 h1:Li1/tTC4H/3cyw6q4weX+U8GWwEL3lTekK/nYa1Cvuk=
14 | github.com/streamer45/silero-vad-go v0.2.1/go.mod h1:B+2FXs/5fZ6pzl6unUZYhZqkYdOB+3saBVzjOzdZnUs=
15 | github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
16 | github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
17 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
18 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
19 |
--------------------------------------------------------------------------------
/examples/haskell/README.md:
--------------------------------------------------------------------------------
1 | # Haskell example
2 |
3 | To run the example, make sure you put an ``example.wav`` in this directory, and then run the following:
4 | ```bash
5 | stack run
6 | ```
7 |
8 | The ``example.wav`` file must meet the following requirements:
9 | - Must have a 16 kHz sample rate.
10 | - Must be mono (single channel).
11 | - Must be 16-bit audio.
12 |
13 | This uses the [silero-vad](https://hackage.haskell.org/package/silero-vad) package, a Haskell implementation based on the C# example.
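14 |
15 | If your source audio does not already meet these requirements, one simple way to produce a compatible ``example.wav`` (assuming ``ffmpeg`` is available) is:
16 | ```bash
17 | # Resample to 16 kHz, downmix to mono, and encode as 16-bit PCM.
18 | ffmpeg -i input.wav -ar 16000 -ac 1 -c:a pcm_s16le example.wav
19 | ```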
--------------------------------------------------------------------------------
/examples/haskell/app/Main.hs:
--------------------------------------------------------------------------------
1 | module Main (main) where
2 |
3 | import qualified Data.Vector.Storable as Vector
4 | import Data.WAVE
5 | import Data.Function
6 | import Silero
7 |
8 | main :: IO ()
9 | main =
10 | withModel $ \model -> do
11 | wav <- getWAVEFile "example.wav"
12 | let samples =
13 | concat (waveSamples wav)
14 | & Vector.fromList
15 | & Vector.map (realToFrac . sampleToDouble)
16 | let vad =
17 | (defaultVad model)
18 | { startThreshold = 0.5
19 | , endThreshold = 0.35
20 | }
21 | segments <- detectSegments vad samples
22 | print segments
--------------------------------------------------------------------------------
/examples/haskell/example.cabal:
--------------------------------------------------------------------------------
1 | cabal-version: 1.12
2 |
3 | -- This file has been generated from package.yaml by hpack version 0.37.0.
4 | --
5 | -- see: https://github.com/sol/hpack
6 |
7 | name: example
8 | version: 0.1.0.0
9 | build-type: Simple
10 |
11 | executable example-exe
12 | main-is: Main.hs
13 | other-modules:
14 | Paths_example
15 | hs-source-dirs:
16 | app
17 | ghc-options: -Wall -Wcompat -Widentities -Wincomplete-record-updates -Wincomplete-uni-patterns -Wmissing-export-lists -Wmissing-home-modules -Wpartial-fields -Wredundant-constraints -threaded -rtsopts -with-rtsopts=-N
18 | build-depends:
19 | WAVE
20 | , base >=4.7 && <5
21 | , silero-vad
22 | , vector
23 | default-language: Haskell2010
24 |
--------------------------------------------------------------------------------
/examples/haskell/package.yaml:
--------------------------------------------------------------------------------
1 | name: example
2 | version: 0.1.0.0
3 |
4 | dependencies:
5 | - base >= 4.7 && < 5
6 | - silero-vad
7 | - WAVE
8 | - vector
9 |
10 | ghc-options:
11 | - -Wall
12 | - -Wcompat
13 | - -Widentities
14 | - -Wincomplete-record-updates
15 | - -Wincomplete-uni-patterns
16 | - -Wmissing-export-lists
17 | - -Wmissing-home-modules
18 | - -Wpartial-fields
19 | - -Wredundant-constraints
20 |
21 | executables:
22 | example-exe:
23 | main: Main.hs
24 | source-dirs: app
25 | ghc-options:
26 | - -threaded
27 | - -rtsopts
28 | - -with-rtsopts=-N
--------------------------------------------------------------------------------
/examples/haskell/stack.yaml:
--------------------------------------------------------------------------------
1 | snapshot:
2 | url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/lts/20/26.yaml
3 |
4 | packages:
5 | - .
6 |
7 | extra-deps:
8 | - silero-vad-0.1.0.4@sha256:2bff95be978a2782915b250edc795760d4cf76838e37bb7d4a965dc32566eb0f,5476
9 | - WAVE-0.1.6@sha256:f744ff68f5e3a0d1f84fab373ea35970659085d213aef20860357512d0458c5c,1016
10 | - derive-storable-0.3.1.0@sha256:bd1c51c155a00e2be18325d553d6764dd678904a85647d6ba952af998e70aa59,2313
11 | - vector-0.13.2.0@sha256:98f5cb3080a3487527476e3c272dcadaba1376539f2aa0646f2f19b3af6b2f67,8481
--------------------------------------------------------------------------------
/examples/haskell/stack.yaml.lock:
--------------------------------------------------------------------------------
1 | # This file was autogenerated by Stack.
2 | # You should not edit this file by hand.
3 | # For more information, please see the documentation at:
4 | # https://docs.haskellstack.org/en/stable/lock_files
5 |
6 | packages:
7 | - completed:
8 | hackage: silero-vad-0.1.0.4@sha256:2bff95be978a2782915b250edc795760d4cf76838e37bb7d4a965dc32566eb0f,5476
9 | pantry-tree:
10 | sha256: a62e813f978d32c87769796fded981d25fcf2875bb2afdf60ed6279f931ccd7f
11 | size: 1391
12 | original:
13 | hackage: silero-vad-0.1.0.4@sha256:2bff95be978a2782915b250edc795760d4cf76838e37bb7d4a965dc32566eb0f,5476
14 | - completed:
15 | hackage: WAVE-0.1.6@sha256:f744ff68f5e3a0d1f84fab373ea35970659085d213aef20860357512d0458c5c,1016
16 | pantry-tree:
17 | sha256: ee5ccd70fa7fe6ffc360ebd762b2e3f44ae10406aa27f3842d55b8cbd1a19498
18 | size: 405
19 | original:
20 | hackage: WAVE-0.1.6@sha256:f744ff68f5e3a0d1f84fab373ea35970659085d213aef20860357512d0458c5c,1016
21 | - completed:
22 | hackage: derive-storable-0.3.1.0@sha256:bd1c51c155a00e2be18325d553d6764dd678904a85647d6ba952af998e70aa59,2313
23 | pantry-tree:
24 | sha256: 48e35a72d1bb593173890616c8d7efd636a650a306a50bb3e1513e679939d27e
25 | size: 902
26 | original:
27 | hackage: derive-storable-0.3.1.0@sha256:bd1c51c155a00e2be18325d553d6764dd678904a85647d6ba952af998e70aa59,2313
28 | - completed:
29 | hackage: vector-0.13.2.0@sha256:98f5cb3080a3487527476e3c272dcadaba1376539f2aa0646f2f19b3af6b2f67,8481
30 | pantry-tree:
31 | sha256: 2176fd677a02a4c47337f7dca5aeca2745dbb821a6ea5c7099b3a991ecd7f4f0
32 | size: 4478
33 | original:
34 | hackage: vector-0.13.2.0@sha256:98f5cb3080a3487527476e3c272dcadaba1376539f2aa0646f2f19b3af6b2f67,8481
35 | snapshots:
36 | - completed:
37 | sha256: 5a59b2a405b3aba3c00188453be172b85893cab8ebc352b1ef58b0eae5d248a2
38 | size: 650475
39 | url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/lts/20/26.yaml
40 | original:
41 | url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/lts/20/26.yaml
42 |
--------------------------------------------------------------------------------
/examples/java-example/pom.xml:
--------------------------------------------------------------------------------
1 | <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
2 |          xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
3 |   <modelVersion>4.0.0</modelVersion>
4 |
5 |   <groupId>org.example</groupId>
6 |   <artifactId>java-example</artifactId>
7 |   <version>1.0-SNAPSHOT</version>
8 |   <packaging>jar</packaging>
9 |
10 |   <name>sliero-vad-example</name>
11 |   <url>http://maven.apache.org</url>
12 |
13 |   <properties>
14 |     <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
15 |   </properties>
16 |
17 |   <dependencies>
18 |     <dependency>
19 |       <groupId>junit</groupId>
20 |       <artifactId>junit</artifactId>
21 |       <version>3.8.1</version>
22 |       <scope>test</scope>
23 |     </dependency>
24 |     <dependency>
25 |       <groupId>com.microsoft.onnxruntime</groupId>
26 |       <artifactId>onnxruntime</artifactId>
27 |       <version>1.16.0-rc1</version>
28 |     </dependency>
29 |   </dependencies>
30 | </project>
--------------------------------------------------------------------------------
/examples/java-example/src/main/java/org/example/App.java:
--------------------------------------------------------------------------------
1 | package org.example;
2 |
3 | import ai.onnxruntime.OrtException;
4 | import javax.sound.sampled.*;
5 | import java.util.Map;
6 |
7 | public class App {
8 |
9 | private static final String MODEL_PATH = "src/main/resources/silero_vad.onnx";
10 | private static final int SAMPLE_RATE = 16000;
11 | private static final float START_THRESHOLD = 0.6f;
12 | private static final float END_THRESHOLD = 0.45f;
13 | private static final int MIN_SILENCE_DURATION_MS = 600;
14 | private static final int SPEECH_PAD_MS = 500;
15 | private static final int WINDOW_SIZE_SAMPLES = 2048;
16 |
17 | public static void main(String[] args) {
18 | // Initialize the Voice Activity Detector
19 | SlieroVadDetector vadDetector;
20 | try {
21 | vadDetector = new SlieroVadDetector(MODEL_PATH, START_THRESHOLD, END_THRESHOLD, SAMPLE_RATE, MIN_SILENCE_DURATION_MS, SPEECH_PAD_MS);
22 | } catch (OrtException e) {
23 | System.err.println("Error initializing the VAD detector: " + e.getMessage());
24 | return;
25 | }
26 |
27 | // Set audio format
28 | AudioFormat format = new AudioFormat(SAMPLE_RATE, 16, 1, true, false);
29 | DataLine.Info info = new DataLine.Info(TargetDataLine.class, format);
30 |
31 | // Get the target data line and open it with the specified format
32 | TargetDataLine targetDataLine;
33 | try {
34 | targetDataLine = (TargetDataLine) AudioSystem.getLine(info);
35 | targetDataLine.open(format);
36 | targetDataLine.start();
37 | } catch (LineUnavailableException e) {
38 | System.err.println("Error opening target data line: " + e.getMessage());
39 | return;
40 | }
41 |
42 | // Main loop to continuously read data and apply Voice Activity Detection
43 | while (targetDataLine.isOpen()) {
44 | byte[] data = new byte[WINDOW_SIZE_SAMPLES];
45 |
46 | int numBytesRead = targetDataLine.read(data, 0, data.length);
47 | if (numBytesRead <= 0) {
48 | System.err.println("Error reading data from target data line.");
49 | continue;
50 | }
51 |
52 | // Apply the Voice Activity Detector to the data and get the result
53 | Map<String, Double> detectResult;
54 | try {
55 | detectResult = vadDetector.apply(data, true);
56 | } catch (Exception e) {
57 | System.err.println("Error applying VAD detector: " + e.getMessage());
58 | continue;
59 | }
60 |
61 | if (!detectResult.isEmpty()) {
62 | System.out.println(detectResult);
63 | }
64 | }
65 |
66 | // Close the target data line to release audio resources
67 | targetDataLine.close();
68 | }
69 | }
70 |
--------------------------------------------------------------------------------
/examples/java-example/src/main/java/org/example/SlieroVadDetector.java:
--------------------------------------------------------------------------------
1 | package org.example;
2 |
3 | import ai.onnxruntime.OrtException;
4 |
5 | import java.math.BigDecimal;
6 | import java.math.RoundingMode;
7 | import java.util.Collections;
8 | import java.util.HashMap;
9 | import java.util.Map;
10 |
11 |
12 | public class SlieroVadDetector {
13 | // OnnxModel model used for speech processing
14 | private final SlieroVadOnnxModel model;
15 | // Threshold for speech start
16 | private final float startThreshold;
17 | // Threshold for speech end
18 | private final float endThreshold;
19 | // Sampling rate
20 | private final int samplingRate;
21 | // Minimum number of silence samples to determine the end threshold of speech
22 | private final float minSilenceSamples;
23 | // Additional number of samples for speech start or end to calculate speech start or end time
24 | private final float speechPadSamples;
25 | // Whether in the triggered state (i.e. whether speech is being detected)
26 | private boolean triggered;
27 | // Temporarily stored number of speech end samples
28 | private int tempEnd;
29 | // Number of samples currently being processed
30 | private int currentSample;
31 |
32 |
33 | public SlieroVadDetector(String modelPath,
34 | float startThreshold,
35 | float endThreshold,
36 | int samplingRate,
37 | int minSilenceDurationMs,
38 | int speechPadMs) throws OrtException {
39 | // Check if the sampling rate is 8000 or 16000, if not, throw an exception
40 | if (samplingRate != 8000 && samplingRate != 16000) {
41 | throw new IllegalArgumentException("does not support sampling rates other than [8000, 16000]");
42 | }
43 |
44 | // Initialize the parameters
45 | this.model = new SlieroVadOnnxModel(modelPath);
46 | this.startThreshold = startThreshold;
47 | this.endThreshold = endThreshold;
48 | this.samplingRate = samplingRate;
49 | this.minSilenceSamples = samplingRate * minSilenceDurationMs / 1000f;
50 | this.speechPadSamples = samplingRate * speechPadMs / 1000f;
51 | // Reset the state
52 | reset();
53 | }
54 |
55 | // Method to reset the state, including the model state, trigger state, temporary end time, and current sample count
56 | public void reset() {
57 | model.resetStates();
58 | triggered = false;
59 | tempEnd = 0;
60 | currentSample = 0;
61 | }
62 |
63 | // apply method for processing the audio array, returning possible speech start or end times
64 | public Map<String, Double> apply(byte[] data, boolean returnSeconds) {
65 |
66 | // Convert the byte array to a float array
67 | float[] audioData = new float[data.length / 2];
68 | for (int i = 0; i < audioData.length; i++) {
69 | audioData[i] = ((data[i * 2] & 0xff) | (data[i * 2 + 1] << 8)) / 32767.0f;
70 | }
71 |
72 | // Get the length of the audio array as the window size
73 | int windowSizeSamples = audioData.length;
74 | // Update the current sample count
75 | currentSample += windowSizeSamples;
76 |
77 | // Call the model to get the prediction probability of speech
78 | float speechProb = 0;
79 | try {
80 | speechProb = model.call(new float[][]{audioData}, samplingRate)[0];
81 | } catch (OrtException e) {
82 | throw new RuntimeException(e);
83 | }
84 |
85 | // If the speech probability is greater than the threshold and the temporary end time is not 0, reset the temporary end time
86 | // This indicates that the speech duration has exceeded expectations and needs to recalculate the end time
87 | if (speechProb >= startThreshold && tempEnd != 0) {
88 | tempEnd = 0;
89 | }
90 |
91 | // If the speech probability is greater than the threshold and not in the triggered state, set to triggered state and calculate the speech start time
92 | if (speechProb >= startThreshold && !triggered) {
93 | triggered = true;
94 | int speechStart = (int) (currentSample - speechPadSamples);
95 | speechStart = Math.max(speechStart, 0);
96 | Map<String, Double> result = new HashMap<>();
97 | // Decide whether to return the result in seconds or sample count based on the returnSeconds parameter
98 | if (returnSeconds) {
99 | double speechStartSeconds = speechStart / (double) samplingRate;
100 | double roundedSpeechStart = BigDecimal.valueOf(speechStartSeconds).setScale(1, RoundingMode.HALF_UP).doubleValue();
101 | result.put("start", roundedSpeechStart);
102 | } else {
103 | result.put("start", (double) speechStart);
104 | }
105 |
106 | return result;
107 | }
108 |
109 | // If the speech probability is less than a certain threshold and in the triggered state, calculate the speech end time
110 | if (speechProb < endThreshold && triggered) {
111 | // Initialize or update the temporary end time
112 | if (tempEnd == 0) {
113 | tempEnd = currentSample;
114 | }
115 | // If the number of silence samples between the current sample and the temporary end time is less than the minimum silence samples, return null
116 | // This indicates that it is not yet possible to determine whether the speech has ended
117 | if (currentSample - tempEnd < minSilenceSamples) {
118 | return Collections.emptyMap();
119 | } else {
120 | // Calculate the speech end time, reset the trigger state and temporary end time
121 | int speechEnd = (int) (tempEnd + speechPadSamples);
122 | tempEnd = 0;
123 | triggered = false;
124 | Map<String, Double> result = new HashMap<>();
125 |
126 | if (returnSeconds) {
127 | double speechEndSeconds = speechEnd / (double) samplingRate;
128 | double roundedSpeechEnd = BigDecimal.valueOf(speechEndSeconds).setScale(1, RoundingMode.HALF_UP).doubleValue();
129 | result.put("end", roundedSpeechEnd);
130 | } else {
131 | result.put("end", (double) speechEnd);
132 | }
133 | return result;
134 | }
135 | }
136 |
137 | // If the above conditions are not met, return null by default
138 | return Collections.emptyMap();
139 | }
140 |
141 | public void close() throws OrtException {
142 | reset();
143 | model.close();
144 | }
145 | }
146 |
--------------------------------------------------------------------------------
/examples/java-example/src/main/java/org/example/SlieroVadOnnxModel.java:
--------------------------------------------------------------------------------
1 | package org.example;
2 |
3 | import ai.onnxruntime.OnnxTensor;
4 | import ai.onnxruntime.OrtEnvironment;
5 | import ai.onnxruntime.OrtException;
6 | import ai.onnxruntime.OrtSession;
7 | import java.util.Arrays;
8 | import java.util.HashMap;
9 | import java.util.List;
10 | import java.util.Map;
11 |
12 | public class SlieroVadOnnxModel {
13 | // Define private variable OrtSession
14 | private final OrtSession session;
15 | private float[][][] h;
16 | private float[][][] c;
17 | // Define the last sample rate
18 | private int lastSr = 0;
19 | // Define the last batch size
20 | private int lastBatchSize = 0;
21 | // Define a list of supported sample rates
22 | private static final List<Integer> SAMPLE_RATES = Arrays.asList(8000, 16000);
23 |
24 | // Constructor
25 | public SlieroVadOnnxModel(String modelPath) throws OrtException {
26 | // Get the ONNX runtime environment
27 | OrtEnvironment env = OrtEnvironment.getEnvironment();
28 | // Create an ONNX session options object
29 | OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
30 | // Set the InterOp thread count to 1, InterOp threads are used for parallel processing of different computation graph operations
31 | opts.setInterOpNumThreads(1);
32 | // Set the IntraOp thread count to 1, IntraOp threads are used for parallel processing within a single operation
33 | opts.setIntraOpNumThreads(1);
34 | // Add a CPU device, setting to false disables CPU execution optimization
35 | opts.addCPU(true);
36 | // Create an ONNX session using the environment, model path, and options
37 | session = env.createSession(modelPath, opts);
38 | // Reset states
39 | resetStates();
40 | }
41 |
42 | /**
43 | * Reset states
44 | */
45 | void resetStates() {
46 | h = new float[2][1][64];
47 | c = new float[2][1][64];
48 | lastSr = 0;
49 | lastBatchSize = 0;
50 | }
51 |
52 | public void close() throws OrtException {
53 | session.close();
54 | }
55 |
56 | /**
57 | * Define inner class ValidationResult
58 | */
59 | public static class ValidationResult {
60 | public final float[][] x;
61 | public final int sr;
62 |
63 | // Constructor
64 | public ValidationResult(float[][] x, int sr) {
65 | this.x = x;
66 | this.sr = sr;
67 | }
68 | }
69 |
70 | /**
71 | * Function to validate input data
72 | */
73 | private ValidationResult validateInput(float[][] x, int sr) {
74 | // Process the input data with dimension 1
75 | if (x.length == 1) {
76 | x = new float[][]{x[0]};
77 | }
78 | // Throw an exception when the input data dimension is greater than 2
79 | if (x.length > 2) {
80 | throw new IllegalArgumentException("Incorrect audio data dimension: " + x[0].length);
81 | }
82 |
83 | // Process the input data when the sample rate is not equal to 16000 and is a multiple of 16000
84 | if (sr != 16000 && (sr % 16000 == 0)) {
85 | int step = sr / 16000;
86 | float[][] reducedX = new float[x.length][];
87 |
88 | for (int i = 0; i < x.length; i++) {
89 | float[] current = x[i];
90 | float[] newArr = new float[(current.length + step - 1) / step];
91 |
92 | for (int j = 0, index = 0; j < current.length; j += step, index++) {
93 | newArr[index] = current[j];
94 | }
95 |
96 | reducedX[i] = newArr;
97 | }
98 |
99 | x = reducedX;
100 | sr = 16000;
101 | }
102 |
103 | // If the sample rate is not in the list of supported sample rates, throw an exception
104 | if (!SAMPLE_RATES.contains(sr)) {
105 | throw new IllegalArgumentException("Only supports sample rates " + SAMPLE_RATES + " (or multiples of 16000)");
106 | }
107 |
108 | // If the input audio block is too short, throw an exception
109 | if (((float) sr) / x[0].length > 31.25) {
110 | throw new IllegalArgumentException("Input audio is too short");
111 | }
112 |
113 | // Return the validated result
114 | return new ValidationResult(x, sr);
115 | }
116 |
117 | /**
118 | * Method to call the ONNX model
119 | */
120 | public float[] call(float[][] x, int sr) throws OrtException {
121 | ValidationResult result = validateInput(x, sr);
122 | x = result.x;
123 | sr = result.sr;
124 |
125 | int batchSize = x.length;
126 |
127 | if (lastBatchSize == 0 || lastSr != sr || lastBatchSize != batchSize) {
128 | resetStates();
129 | }
130 |
131 | OrtEnvironment env = OrtEnvironment.getEnvironment();
132 |
133 | OnnxTensor inputTensor = null;
134 | OnnxTensor hTensor = null;
135 | OnnxTensor cTensor = null;
136 | OnnxTensor srTensor = null;
137 | OrtSession.Result ortOutputs = null;
138 |
139 | try {
140 | // Create input tensors
141 | inputTensor = OnnxTensor.createTensor(env, x);
142 | hTensor = OnnxTensor.createTensor(env, h);
143 | cTensor = OnnxTensor.createTensor(env, c);
144 | srTensor = OnnxTensor.createTensor(env, new long[]{sr});
145 |
146 | Map<String, OnnxTensor> inputs = new HashMap<>();
147 | inputs.put("input", inputTensor);
148 | inputs.put("sr", srTensor);
149 | inputs.put("h", hTensor);
150 | inputs.put("c", cTensor);
151 |
152 | // Call the ONNX model for calculation
153 | ortOutputs = session.run(inputs);
154 | // Get the output results
155 | float[][] output = (float[][]) ortOutputs.get(0).getValue();
156 | h = (float[][][]) ortOutputs.get(1).getValue();
157 | c = (float[][][]) ortOutputs.get(2).getValue();
158 |
159 | lastSr = sr;
160 | lastBatchSize = batchSize;
161 | return output[0];
162 | } finally {
163 | if (inputTensor != null) {
164 | inputTensor.close();
165 | }
166 | if (hTensor != null) {
167 | hTensor.close();
168 | }
169 | if (cTensor != null) {
170 | cTensor.close();
171 | }
172 | if (srTensor != null) {
173 | srTensor.close();
174 | }
175 | if (ortOutputs != null) {
176 | ortOutputs.close();
177 | }
178 | }
179 | }
180 | }
181 |
--------------------------------------------------------------------------------
/examples/java-wav-file-example/src/main/java/org/example/App.java:
--------------------------------------------------------------------------------
1 | package org.example;
2 |
3 | import ai.onnxruntime.OrtException;
4 | import java.io.File;
5 | import java.util.List;
6 |
7 | public class App {
8 |
9 | private static final String MODEL_PATH = "/path/silero_vad.onnx";
10 | private static final String EXAMPLE_WAV_FILE = "/path/example.wav";
11 | private static final int SAMPLE_RATE = 16000;
12 | private static final float THRESHOLD = 0.5f;
13 | private static final int MIN_SPEECH_DURATION_MS = 250;
14 | private static final float MAX_SPEECH_DURATION_SECONDS = Float.POSITIVE_INFINITY;
15 | private static final int MIN_SILENCE_DURATION_MS = 100;
16 | private static final int SPEECH_PAD_MS = 30;
17 |
18 | public static void main(String[] args) {
19 | // Initialize the Voice Activity Detector
20 | SileroVadDetector vadDetector;
21 | try {
22 | vadDetector = new SileroVadDetector(MODEL_PATH, THRESHOLD, SAMPLE_RATE,
23 | MIN_SPEECH_DURATION_MS, MAX_SPEECH_DURATION_SECONDS, MIN_SILENCE_DURATION_MS, SPEECH_PAD_MS);
24 | fromWavFile(vadDetector, new File(EXAMPLE_WAV_FILE));
25 | } catch (OrtException e) {
26 | System.err.println("Error initializing the VAD detector: " + e.getMessage());
27 | }
28 | }
29 |
30 | public static void fromWavFile(SileroVadDetector vadDetector, File wavFile) {
31 | List<SileroSpeechSegment> speechTimeList = vadDetector.getSpeechSegmentList(wavFile);
32 | for (SileroSpeechSegment speechSegment : speechTimeList) {
33 | System.out.println(String.format("start second: %f, end second: %f",
34 | speechSegment.getStartSecond(), speechSegment.getEndSecond()));
35 | }
36 | }
37 | }
38 |
--------------------------------------------------------------------------------
/examples/java-wav-file-example/src/main/java/org/example/SileroSpeechSegment.java:
--------------------------------------------------------------------------------
1 | package org.example;
2 |
3 |
4 | public class SileroSpeechSegment {
5 | private Integer startOffset;
6 | private Integer endOffset;
7 | private Float startSecond;
8 | private Float endSecond;
9 |
10 | public SileroSpeechSegment() {
11 | }
12 |
13 | public SileroSpeechSegment(Integer startOffset, Integer endOffset, Float startSecond, Float endSecond) {
14 | this.startOffset = startOffset;
15 | this.endOffset = endOffset;
16 | this.startSecond = startSecond;
17 | this.endSecond = endSecond;
18 | }
19 |
20 | public Integer getStartOffset() {
21 | return startOffset;
22 | }
23 |
24 | public Integer getEndOffset() {
25 | return endOffset;
26 | }
27 |
28 | public Float getStartSecond() {
29 | return startSecond;
30 | }
31 |
32 | public Float getEndSecond() {
33 | return endSecond;
34 | }
35 |
36 | public void setStartOffset(Integer startOffset) {
37 | this.startOffset = startOffset;
38 | }
39 |
40 | public void setEndOffset(Integer endOffset) {
41 | this.endOffset = endOffset;
42 | }
43 |
44 | public void setStartSecond(Float startSecond) {
45 | this.startSecond = startSecond;
46 | }
47 |
48 | public void setEndSecond(Float endSecond) {
49 | this.endSecond = endSecond;
50 | }
51 | }
52 |
--------------------------------------------------------------------------------
/examples/java-wav-file-example/src/main/java/org/example/SileroVadDetector.java:
--------------------------------------------------------------------------------
1 | package org.example;
2 |
3 |
4 | import ai.onnxruntime.OrtException;
5 |
6 | import javax.sound.sampled.AudioInputStream;
7 | import javax.sound.sampled.AudioSystem;
8 | import java.io.File;
9 | import java.util.ArrayList;
10 | import java.util.Comparator;
11 | import java.util.List;
12 |
13 | public class SileroVadDetector {
14 | private final SileroVadOnnxModel model;
15 | private final float threshold;
16 | private final float negThreshold;
17 | private final int samplingRate;
18 | private final int windowSizeSample;
19 | private final float minSpeechSamples;
20 | private final float speechPadSamples;
21 | private final float maxSpeechSamples;
22 | private final float minSilenceSamples;
23 | private final float minSilenceSamplesAtMaxSpeech;
24 | private int audioLengthSamples;
25 | private static final float THRESHOLD_GAP = 0.15f;
26 | private static final Integer SAMPLING_RATE_8K = 8000;
27 | private static final Integer SAMPLING_RATE_16K = 16000;
28 |
29 | /**
30 | * Constructor
31 | * @param onnxModelPath the path of silero-vad onnx model
32 | * @param threshold threshold for speech start
33 | * @param samplingRate audio sampling rate, only available for [8k, 16k]
34 | * @param minSpeechDurationMs Minimum speech length in millis; any speech chunk shorter than this value is not considered speech
35 | * @param maxSpeechDurationSeconds Maximum speech length in seconds; Float.POSITIVE_INFINITY is recommended
36 | * @param minSilenceDurationMs Minimum silence length in millis; any silence shorter than this value is not considered silence
37 | * @param speechPadMs Additional pad millis for speech start and end
38 | * @throws OrtException
39 | */
40 | public SileroVadDetector(String onnxModelPath, float threshold, int samplingRate,
41 | int minSpeechDurationMs, float maxSpeechDurationSeconds,
42 | int minSilenceDurationMs, int speechPadMs) throws OrtException {
43 | if (samplingRate != SAMPLING_RATE_8K && samplingRate != SAMPLING_RATE_16K) {
44 | throw new IllegalArgumentException("Sampling rate not supported, only 8000 and 16000 are available");
45 | }
46 | this.model = new SileroVadOnnxModel(onnxModelPath);
47 | this.samplingRate = samplingRate;
48 | this.threshold = threshold;
49 | this.negThreshold = threshold - THRESHOLD_GAP;
50 | if (samplingRate == SAMPLING_RATE_16K) {
51 | this.windowSizeSample = 512;
52 | } else {
53 | this.windowSizeSample = 256;
54 | }
55 | this.minSpeechSamples = samplingRate * minSpeechDurationMs / 1000f;
56 | this.speechPadSamples = samplingRate * speechPadMs / 1000f;
57 | this.maxSpeechSamples = samplingRate * maxSpeechDurationSeconds - windowSizeSample - 2 * speechPadSamples;
58 | this.minSilenceSamples = samplingRate * minSilenceDurationMs / 1000f;
59 | this.minSilenceSamplesAtMaxSpeech = samplingRate * 98 / 1000f;
60 | this.reset();
61 | }
62 |
63 | /**
64 | * Method to reset the state
65 | */
66 | public void reset() {
67 | model.resetStates();
68 | }
69 |
70 | /**
71 | * Get speech segment list by given wav-format file
72 | * @param wavFile wav file
73 | * @return list of speech segment
74 | */
75 | public List<SileroSpeechSegment> getSpeechSegmentList(File wavFile) {
76 | reset();
77 | try (AudioInputStream audioInputStream = AudioSystem.getAudioInputStream(wavFile)){
78 | List<Float> speechProbList = new ArrayList<>();
79 | this.audioLengthSamples = audioInputStream.available() / 2;
80 | byte[] data = new byte[this.windowSizeSample * 2];
81 | int numBytesRead = 0;
82 |
83 | while ((numBytesRead = audioInputStream.read(data)) != -1) {
84 | if (numBytesRead <= 0) {
85 | break;
86 | }
87 | // Convert the byte array to a float array
88 | float[] audioData = new float[data.length / 2];
89 | for (int i = 0; i < audioData.length; i++) {
90 | audioData[i] = ((data[i * 2] & 0xff) | (data[i * 2 + 1] << 8)) / 32767.0f;
91 | }
92 |
93 | float speechProb = 0;
94 | try {
95 | speechProb = model.call(new float[][]{audioData}, samplingRate)[0];
96 | speechProbList.add(speechProb);
97 | } catch (OrtException e) {
98 | throw e;
99 | }
100 | }
101 | return calculateProb(speechProbList);
102 | } catch (Exception e) {
103 | throw new RuntimeException("SileroVadDetector getSpeechTimeList with error", e);
104 | }
105 | }
106 |
107 | /**
108 | * Calculate speech segments by probability
109 | * @param speechProbList speech probability list
110 | * @return list of speech segment
111 | */
112 | private List<SileroSpeechSegment> calculateProb(List<Float> speechProbList) {
113 | List<SileroSpeechSegment> result = new ArrayList<>();
114 | boolean triggered = false;
115 | int tempEnd = 0, prevEnd = 0, nextStart = 0;
116 | SileroSpeechSegment segment = new SileroSpeechSegment();
117 |
118 | for (int i = 0; i < speechProbList.size(); i++) {
119 | Float speechProb = speechProbList.get(i);
120 | if (speechProb >= threshold && (tempEnd != 0)) {
121 | tempEnd = 0;
122 | if (nextStart < prevEnd) {
123 | nextStart = windowSizeSample * i;
124 | }
125 | }
126 |
127 | if (speechProb >= threshold && !triggered) {
128 | triggered = true;
129 | segment.setStartOffset(windowSizeSample * i);
130 | continue;
131 | }
132 |
133 | if (triggered && (windowSizeSample * i) - segment.getStartOffset() > maxSpeechSamples) {
134 | if (prevEnd != 0) {
135 | segment.setEndOffset(prevEnd);
136 | result.add(segment);
137 | segment = new SileroSpeechSegment();
138 | if (nextStart < prevEnd) {
139 | triggered = false;
140 | }else {
141 | segment.setStartOffset(nextStart);
142 | }
143 | prevEnd = 0;
144 | nextStart = 0;
145 | tempEnd = 0;
146 | }else {
147 | segment.setEndOffset(windowSizeSample * i);
148 | result.add(segment);
149 | segment = new SileroSpeechSegment();
150 | prevEnd = 0;
151 | nextStart = 0;
152 | tempEnd = 0;
153 | triggered = false;
154 | continue;
155 | }
156 | }
157 |
158 | if (speechProb < negThreshold && triggered) {
159 | if (tempEnd == 0) {
160 | tempEnd = windowSizeSample * i;
161 | }
162 | if (((windowSizeSample * i) - tempEnd) > minSilenceSamplesAtMaxSpeech) {
163 | prevEnd = tempEnd;
164 | }
165 | if ((windowSizeSample * i) - tempEnd < minSilenceSamples) {
166 | continue;
167 | }else {
168 | segment.setEndOffset(tempEnd);
169 | if ((segment.getEndOffset() - segment.getStartOffset()) > minSpeechSamples) {
170 | result.add(segment);
171 | }
172 | segment = new SileroSpeechSegment();
173 | prevEnd = 0;
174 | nextStart = 0;
175 | tempEnd = 0;
176 | triggered = false;
177 | continue;
178 | }
179 | }
180 | }
181 |
182 | if (segment.getStartOffset() != null && (audioLengthSamples - segment.getStartOffset()) > minSpeechSamples) {
183 | segment.setEndOffset(audioLengthSamples);
184 | result.add(segment);
185 | }
186 |
187 | for (int i = 0; i < result.size(); i++) {
188 | SileroSpeechSegment item = result.get(i);
189 | if (i == 0) {
190 | item.setStartOffset((int)(Math.max(0,item.getStartOffset() - speechPadSamples)));
191 | }
192 | if (i != result.size() - 1) {
193 | SileroSpeechSegment nextItem = result.get(i + 1);
194 | Integer silenceDuration = nextItem.getStartOffset() - item.getEndOffset();
195 | if(silenceDuration < 2 * speechPadSamples){
196 | item.setEndOffset(item.getEndOffset() + (silenceDuration / 2 ));
197 | nextItem.setStartOffset(Math.max(0, nextItem.getStartOffset() - (silenceDuration / 2)));
198 | } else {
199 | item.setEndOffset((int)(Math.min(audioLengthSamples, item.getEndOffset() + speechPadSamples)));
200 | nextItem.setStartOffset((int)(Math.max(0,nextItem.getStartOffset() - speechPadSamples)));
201 | }
202 | }else {
203 | item.setEndOffset((int)(Math.min(audioLengthSamples, item.getEndOffset() + speechPadSamples)));
204 | }
205 | }
206 |
207 | return mergeListAndCalculateSecond(result, samplingRate);
208 | }
209 |
210 | private List<SileroSpeechSegment> mergeListAndCalculateSecond(List<SileroSpeechSegment> original, Integer samplingRate) {
211 | List<SileroSpeechSegment> result = new ArrayList<>();
212 | if (original == null || original.size() == 0) {
213 | return result;
214 | }
215 | Integer left = original.get(0).getStartOffset();
216 | Integer right = original.get(0).getEndOffset();
217 | if (original.size() > 1) {
218 | original.sort(Comparator.comparingLong(SileroSpeechSegment::getStartOffset));
219 | for (int i = 1; i < original.size(); i++) {
220 | SileroSpeechSegment segment = original.get(i);
221 |
222 | if (segment.getStartOffset() > right) {
223 | result.add(new SileroSpeechSegment(left, right,
224 | calculateSecondByOffset(left, samplingRate), calculateSecondByOffset(right, samplingRate)));
225 | left = segment.getStartOffset();
226 | right = segment.getEndOffset();
227 | } else {
228 | right = Math.max(right, segment.getEndOffset());
229 | }
230 | }
231 | result.add(new SileroSpeechSegment(left, right,
232 | calculateSecondByOffset(left, samplingRate), calculateSecondByOffset(right, samplingRate)));
233 | }else {
234 | result.add(new SileroSpeechSegment(left, right,
235 | calculateSecondByOffset(left, samplingRate), calculateSecondByOffset(right, samplingRate)));
236 | }
237 | return result;
238 | }
239 |
240 | private Float calculateSecondByOffset(Integer offset, Integer samplingRate) {
241 | float secondValue = offset * 1.0f / samplingRate;
242 | return (float) Math.floor(secondValue * 1000.0f) / 1000.0f;
243 | }
244 | }
245 |
--------------------------------------------------------------------------------
/examples/java-wav-file-example/src/main/java/org/example/SileroVadOnnxModel.java:
--------------------------------------------------------------------------------
1 | package org.example;
2 |
3 | import ai.onnxruntime.OnnxTensor;
4 | import ai.onnxruntime.OrtEnvironment;
5 | import ai.onnxruntime.OrtException;
6 | import ai.onnxruntime.OrtSession;
7 | import java.util.Arrays;
8 | import java.util.HashMap;
9 | import java.util.List;
10 | import java.util.Map;
11 |
12 | public class SileroVadOnnxModel {
13 | // Define private variable OrtSession
14 | private final OrtSession session;
15 | private float[][][] state;
16 | private float[][] context;
17 | // Define the last sample rate
18 | private int lastSr = 0;
19 | // Define the last batch size
20 | private int lastBatchSize = 0;
21 | // Define a list of supported sample rates
22 |     private static final List<Integer> SAMPLE_RATES = Arrays.asList(8000, 16000);
23 |
24 | // Constructor
25 | public SileroVadOnnxModel(String modelPath) throws OrtException {
26 | // Get the ONNX runtime environment
27 | OrtEnvironment env = OrtEnvironment.getEnvironment();
28 | // Create an ONNX session options object
29 | OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
30 | // Set the InterOp thread count to 1, InterOp threads are used for parallel processing of different computation graph operations
31 | opts.setInterOpNumThreads(1);
32 | // Set the IntraOp thread count to 1, IntraOp threads are used for parallel processing within a single operation
33 | opts.setIntraOpNumThreads(1);
34 | // Add a CPU device, setting to false disables CPU execution optimization
35 | opts.addCPU(true);
36 | // Create an ONNX session using the environment, model path, and options
37 | session = env.createSession(modelPath, opts);
38 | // Reset states
39 | resetStates();
40 | }
41 |
42 | /**
43 | * Reset states
44 | */
45 | void resetStates() {
46 | state = new float[2][1][128];
47 | context = new float[0][];
48 | lastSr = 0;
49 | lastBatchSize = 0;
50 | }
51 |
52 | public void close() throws OrtException {
53 | session.close();
54 | }
55 |
56 | /**
57 | * Define inner class ValidationResult
58 | */
59 | public static class ValidationResult {
60 | public final float[][] x;
61 | public final int sr;
62 |
63 | // Constructor
64 | public ValidationResult(float[][] x, int sr) {
65 | this.x = x;
66 | this.sr = sr;
67 | }
68 | }
69 |
70 | /**
71 | * Function to validate input data
72 | */
73 | private ValidationResult validateInput(float[][] x, int sr) {
74 | // Process the input data with dimension 1
75 | if (x.length == 1) {
76 | x = new float[][]{x[0]};
77 | }
78 | // Throw an exception when the input data dimension is greater than 2
79 | if (x.length > 2) {
80 |             throw new IllegalArgumentException("Incorrect audio data dimension: " + x.length);
81 | }
82 |
83 | // Process the input data when the sample rate is not equal to 16000 and is a multiple of 16000
84 | if (sr != 16000 && (sr % 16000 == 0)) {
85 | int step = sr / 16000;
86 | float[][] reducedX = new float[x.length][];
87 |
88 | for (int i = 0; i < x.length; i++) {
89 | float[] current = x[i];
90 | float[] newArr = new float[(current.length + step - 1) / step];
91 |
92 | for (int j = 0, index = 0; j < current.length; j += step, index++) {
93 | newArr[index] = current[j];
94 | }
95 |
96 | reducedX[i] = newArr;
97 | }
98 |
99 | x = reducedX;
100 | sr = 16000;
101 | }
102 |
103 | // If the sample rate is not in the list of supported sample rates, throw an exception
104 | if (!SAMPLE_RATES.contains(sr)) {
105 | throw new IllegalArgumentException("Only supports sample rates " + SAMPLE_RATES + " (or multiples of 16000)");
106 | }
107 |
108 | // If the input audio block is too short, throw an exception
109 | if (((float) sr) / x[0].length > 31.25) {
110 | throw new IllegalArgumentException("Input audio is too short");
111 | }
112 |
113 | // Return the validated result
114 | return new ValidationResult(x, sr);
115 | }
116 |
117 | private static float[][] concatenate(float[][] a, float[][] b) {
118 | if (a.length != b.length) {
119 | throw new IllegalArgumentException("The number of rows in both arrays must be the same.");
120 | }
121 |
122 | int rows = a.length;
123 | int colsA = a[0].length;
124 | int colsB = b[0].length;
125 | float[][] result = new float[rows][colsA + colsB];
126 |
127 | for (int i = 0; i < rows; i++) {
128 | System.arraycopy(a[i], 0, result[i], 0, colsA);
129 | System.arraycopy(b[i], 0, result[i], colsA, colsB);
130 | }
131 |
132 | return result;
133 | }
134 |
135 | private static float[][] getLastColumns(float[][] array, int contextSize) {
136 | int rows = array.length;
137 | int cols = array[0].length;
138 |
139 | if (contextSize > cols) {
140 | throw new IllegalArgumentException("contextSize cannot be greater than the number of columns in the array.");
141 | }
142 |
143 | float[][] result = new float[rows][contextSize];
144 |
145 | for (int i = 0; i < rows; i++) {
146 | System.arraycopy(array[i], cols - contextSize, result[i], 0, contextSize);
147 | }
148 |
149 | return result;
150 | }
151 |
152 | /**
153 | * Method to call the ONNX model
154 | */
155 | public float[] call(float[][] x, int sr) throws OrtException {
156 | ValidationResult result = validateInput(x, sr);
157 | x = result.x;
158 | sr = result.sr;
159 | int numberSamples = 256;
160 | if (sr == 16000) {
161 | numberSamples = 512;
162 | }
163 |
164 | if (x[0].length != numberSamples) {
165 | throw new IllegalArgumentException("Provided number of samples is " + x[0].length + " (Supported values: 256 for 8000 sample rate, 512 for 16000)");
166 | }
167 |
168 | int batchSize = x.length;
169 |
170 | int contextSize = 32;
171 | if (sr == 16000) {
172 | contextSize = 64;
173 | }
174 |
175 | if (lastBatchSize == 0) {
176 | resetStates();
177 | }
178 | if (lastSr != 0 && lastSr != sr) {
179 | resetStates();
180 | }
181 | if (lastBatchSize != 0 && lastBatchSize != batchSize) {
182 | resetStates();
183 | }
184 |
185 | if (context.length == 0) {
186 | context = new float[batchSize][contextSize];
187 | }
188 |
189 | x = concatenate(context, x);
190 |
191 | OrtEnvironment env = OrtEnvironment.getEnvironment();
192 |
193 | OnnxTensor inputTensor = null;
194 | OnnxTensor stateTensor = null;
195 | OnnxTensor srTensor = null;
196 | OrtSession.Result ortOutputs = null;
197 |
198 | try {
199 | // Create input tensors
200 | inputTensor = OnnxTensor.createTensor(env, x);
201 | stateTensor = OnnxTensor.createTensor(env, state);
202 | srTensor = OnnxTensor.createTensor(env, new long[]{sr});
203 |
204 |             Map<String, OnnxTensor> inputs = new HashMap<>();
205 | inputs.put("input", inputTensor);
206 | inputs.put("sr", srTensor);
207 | inputs.put("state", stateTensor);
208 |
209 | // Call the ONNX model for calculation
210 | ortOutputs = session.run(inputs);
211 | // Get the output results
212 | float[][] output = (float[][]) ortOutputs.get(0).getValue();
213 | state = (float[][][]) ortOutputs.get(1).getValue();
214 |
215 | context = getLastColumns(x, contextSize);
216 | lastSr = sr;
217 | lastBatchSize = batchSize;
218 | return output[0];
219 | } finally {
220 | if (inputTensor != null) {
221 | inputTensor.close();
222 | }
223 | if (stateTensor != null) {
224 | stateTensor.close();
225 | }
226 | if (srTensor != null) {
227 | srTensor.close();
228 | }
229 | if (ortOutputs != null) {
230 | ortOutputs.close();
231 | }
232 | }
233 | }
234 | }
235 |
--------------------------------------------------------------------------------
/examples/microphone_and_webRTC_integration/README.md:
--------------------------------------------------------------------------------
1 |
2 | In this example, the microphone input is integrated with the webRTC VAD. I used [this](https://github.com/mozilla/DeepSpeech-examples/tree/r0.8/mic_vad_streaming) as a starting point.
3 | Here is a short video presenting the results:
4 |
5 | https://user-images.githubusercontent.com/28188499/116685087-182ff100-a9b2-11eb-927d-ed9f621226ee.mp4
6 |
7 | # Requirements:
8 | The libraries used for the following example are:
9 | ```
10 | Python == 3.6.9
11 | webrtcvad >= 2.0.10
12 | torchaudio >= 0.8.1
13 | torch >= 1.8.1
14 | halo >= 0.0.31
15 | Soundfile >= 0.13.3
16 | ```
17 | Using pip3:
18 | ```
19 | pip3 install webrtcvad
20 | pip3 install torchaudio
21 | pip3 install torch
22 | pip3 install halo
23 | pip3 install soundfile
24 | ```
25 | Moreover, to keep the code simple, the default sample_rate is 16 kHz and no resampling is performed.
26 | 
27 | This example has been tested on ```Ubuntu 18.04.3 LTS```
28 |
29 |
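
To run the script with the flags defined in `microphone_and_webRTC_integration.py` (the values here are only an illustration):
```
python3 microphone_and_webRTC_integration.py --webRTC_aggressiveness 3 --nospinner
```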
--------------------------------------------------------------------------------
/examples/microphone_and_webRTC_integration/microphone_and_webRTC_integration.py:
--------------------------------------------------------------------------------
1 | import collections, queue
2 | import numpy as np
3 | import pyaudio
4 | import webrtcvad
5 | from halo import Halo
6 | import torch
7 | import torchaudio
8 |
9 | class Audio(object):
10 | """Streams raw audio from microphone. Data is received in a separate thread, and stored in a buffer, to be read from."""
11 |
12 | FORMAT = pyaudio.paInt16
13 | # Network/VAD rate-space
14 | RATE_PROCESS = 16000
15 | CHANNELS = 1
16 | BLOCKS_PER_SECOND = 50
17 |
18 | def __init__(self, callback=None, device=None, input_rate=RATE_PROCESS):
19 | def proxy_callback(in_data, frame_count, time_info, status):
20 | #pylint: disable=unused-argument
21 | callback(in_data)
22 | return (None, pyaudio.paContinue)
23 | if callback is None: callback = lambda in_data: self.buffer_queue.put(in_data)
24 | self.buffer_queue = queue.Queue()
25 | self.device = device
26 | self.input_rate = input_rate
27 | self.sample_rate = self.RATE_PROCESS
28 | self.block_size = int(self.RATE_PROCESS / float(self.BLOCKS_PER_SECOND))
29 | self.block_size_input = int(self.input_rate / float(self.BLOCKS_PER_SECOND))
30 | self.pa = pyaudio.PyAudio()
31 |
32 | kwargs = {
33 | 'format': self.FORMAT,
34 | 'channels': self.CHANNELS,
35 | 'rate': self.input_rate,
36 | 'input': True,
37 | 'frames_per_buffer': self.block_size_input,
38 | 'stream_callback': proxy_callback,
39 | }
40 |
41 | self.chunk = None
42 | # if not default device
43 | if self.device:
44 | kwargs['input_device_index'] = self.device
45 |
46 | self.stream = self.pa.open(**kwargs)
47 | self.stream.start_stream()
48 |
49 | def read(self):
50 | """Return a block of audio data, blocking if necessary."""
51 | return self.buffer_queue.get()
52 |
53 | def destroy(self):
54 | self.stream.stop_stream()
55 | self.stream.close()
56 | self.pa.terminate()
57 |
58 | frame_duration_ms = property(lambda self: 1000 * self.block_size // self.sample_rate)
59 |
60 |
61 | class VADAudio(Audio):
62 | """Filter & segment audio with voice activity detection."""
63 |
64 | def __init__(self, aggressiveness=3, device=None, input_rate=None):
65 | super().__init__(device=device, input_rate=input_rate)
66 | self.vad = webrtcvad.Vad(aggressiveness)
67 |
68 | def frame_generator(self):
69 | """Generator that yields all audio frames from microphone."""
70 | if self.input_rate == self.RATE_PROCESS:
71 | while True:
72 | yield self.read()
73 | else:
74 | raise Exception("Resampling required")
75 |
76 | def vad_collector(self, padding_ms=300, ratio=0.75, frames=None):
77 |         """Generator that yields series of consecutive audio frames comprising each utterance, separated by yielding a single None.
78 | Determines voice activity by ratio of frames in padding_ms. Uses a buffer to include padding_ms prior to being triggered.
79 | Example: (frame, ..., frame, None, frame, ..., frame, None, ...)
80 |                       |---utterance---| |---utterance---|
81 | """
82 | if frames is None: frames = self.frame_generator()
83 | num_padding_frames = padding_ms // self.frame_duration_ms
84 | ring_buffer = collections.deque(maxlen=num_padding_frames)
85 | triggered = False
86 |
87 | for frame in frames:
88 | if len(frame) < 640:
89 | return
90 |
91 | is_speech = self.vad.is_speech(frame, self.sample_rate)
92 |
93 | if not triggered:
94 | ring_buffer.append((frame, is_speech))
95 | num_voiced = len([f for f, speech in ring_buffer if speech])
96 | if num_voiced > ratio * ring_buffer.maxlen:
97 | triggered = True
98 | for f, s in ring_buffer:
99 | yield f
100 | ring_buffer.clear()
101 |
102 | else:
103 | yield frame
104 | ring_buffer.append((frame, is_speech))
105 | num_unvoiced = len([f for f, speech in ring_buffer if not speech])
106 | if num_unvoiced > ratio * ring_buffer.maxlen:
107 | triggered = False
108 | yield None
109 | ring_buffer.clear()
110 |
111 | def main(ARGS):
112 | # Start audio with VAD
113 | vad_audio = VADAudio(aggressiveness=ARGS.webRTC_aggressiveness,
114 | device=ARGS.device,
115 | input_rate=ARGS.rate)
116 |
117 | print("Listening (ctrl-C to exit)...")
118 | frames = vad_audio.vad_collector()
119 |
120 | # load silero VAD
121 | torchaudio.set_audio_backend("soundfile")
122 | model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
123 | model=ARGS.silaro_model_name,
124 | force_reload= ARGS.reload)
125 | (get_speech_ts,_,_, _,_, _, _) = utils
126 |
127 |
128 | # Stream from microphone to DeepSpeech using VAD
129 | spinner = None
130 | if not ARGS.nospinner:
131 | spinner = Halo(spinner='line')
132 | wav_data = bytearray()
133 | for frame in frames:
134 | if frame is not None:
135 | if spinner: spinner.start()
136 |
137 | wav_data.extend(frame)
138 | else:
139 | if spinner: spinner.stop()
140 | print("webRTC has detected a possible speech")
141 |
142 | newsound= np.frombuffer(wav_data,np.int16)
143 | audio_float32=Int2Float(newsound)
144 | time_stamps =get_speech_ts(audio_float32, model,num_steps=ARGS.num_steps,trig_sum=ARGS.trig_sum,neg_trig_sum=ARGS.neg_trig_sum,
145 | num_samples_per_window=ARGS.num_samples_per_window,min_speech_samples=ARGS.min_speech_samples,
146 | min_silence_samples=ARGS.min_silence_samples)
147 |
148 | if(len(time_stamps)>0):
149 | print("silero VAD has detected a possible speech")
150 | else:
151 | print("silero VAD has detected a noise")
152 | print()
153 | wav_data = bytearray()
154 |
155 |
156 | def Int2Float(sound):
157 | _sound = np.copy(sound) #
158 | abs_max = np.abs(_sound).max()
159 | _sound = _sound.astype('float32')
160 | if abs_max > 0:
161 | _sound *= 1/abs_max
162 | audio_float32 = torch.from_numpy(_sound.squeeze())
163 | return audio_float32
164 |
165 | if __name__ == '__main__':
166 | DEFAULT_SAMPLE_RATE = 16000
167 |
168 | import argparse
169 | parser = argparse.ArgumentParser(description="Stream from microphone to webRTC and silero VAD")
170 |
171 | parser.add_argument('-v', '--webRTC_aggressiveness', type=int, default=3,
172 | help="Set aggressiveness of webRTC: an integer between 0 and 3, 0 being the least aggressive about filtering out non-speech, 3 the most aggressive. Default: 3")
173 | parser.add_argument('--nospinner', action='store_true',
174 | help="Disable spinner")
175 | parser.add_argument('-d', '--device', type=int, default=None,
176 | help="Device input index (Int) as listed by pyaudio.PyAudio.get_device_info_by_index(). If not provided, falls back to PyAudio.get_default_device().")
177 |
178 | parser.add_argument('-name', '--silaro_model_name', type=str, default="silero_vad",
179 |                         help="select the name of the model. You can select between 'silero_vad', 'silero_vad_micro', 'silero_vad_micro_8k', 'silero_vad_mini', 'silero_vad_mini_8k'")
180 |     parser.add_argument('--reload', action='store_true', help="download the latest version of the silero VAD")
181 |
182 | parser.add_argument('-ts', '--trig_sum', type=float, default=0.25,
183 | help="overlapping windows are used for each audio chunk, trig sum defines average probability among those windows for switching into triggered state (speech state)")
184 |
185 | parser.add_argument('-nts', '--neg_trig_sum', type=float, default=0.07,
186 | help="same as trig_sum, but for switching from triggered to non-triggered state (non-speech)")
187 |
188 | parser.add_argument('-N', '--num_steps', type=int, default=8,
189 | help="number of overlapping windows to split audio chunk into (we recommend 4 or 8)")
190 |
191 | parser.add_argument('-nspw', '--num_samples_per_window', type=int, default=4000,
192 | help="number of samples in each window, our models were trained using 4000 samples (250 ms) per window, so this is preferable value (lesser values reduce quality)")
193 |
194 | parser.add_argument('-msps', '--min_speech_samples', type=int, default=10000,
195 | help="minimum speech chunk duration in samples")
196 |
197 | parser.add_argument('-msis', '--min_silence_samples', type=int, default=500,
198 |                         help="minimum silence duration in samples between two separate speech chunks")
199 | ARGS = parser.parse_args()
200 | ARGS.rate=DEFAULT_SAMPLE_RATE
201 | main(ARGS)
202 |
--------------------------------------------------------------------------------
/examples/parallel_example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## Install Dependencies"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "# !pip install -q torchaudio\n",
17 | "SAMPLING_RATE = 16000\n",
18 | "import torch\n",
19 | "from pprint import pprint\n",
20 | "import time\n",
21 | "import shutil\n",
22 | "\n",
23 | "torch.set_num_threads(1)\n",
24 | "NUM_PROCESS=4 # set to the number of CPU cores in the machine\n",
25 | "NUM_COPIES=8\n",
26 | "# download wav files, make multiple copies\n",
27 | "torch.hub.download_url_to_file('https://models.silero.ai/vad_models/en.wav', f\"en_example0.wav\")\n",
28 | "for idx in range(NUM_COPIES-1):\n",
29 | " shutil.copy(f\"en_example0.wav\", f\"en_example{idx+1}.wav\")"
30 | ]
31 | },
32 | {
33 | "cell_type": "markdown",
34 | "metadata": {},
35 | "source": [
36 | "## Load VAD model from torch hub"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": null,
42 | "metadata": {},
43 | "outputs": [],
44 | "source": [
45 | "model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
46 | " model='silero_vad',\n",
47 | " force_reload=True,\n",
48 | " onnx=False)\n",
49 | "\n",
50 | "(get_speech_timestamps,\n",
51 | "save_audio,\n",
52 | "read_audio,\n",
53 | "VADIterator,\n",
54 | "collect_chunks) = utils"
55 | ]
56 | },
57 | {
58 | "cell_type": "markdown",
59 | "metadata": {},
60 | "source": [
61 | "## Define a vad process function"
62 | ]
63 | },
64 | {
65 | "cell_type": "code",
66 | "execution_count": null,
67 | "metadata": {},
68 | "outputs": [],
69 | "source": [
70 | "import multiprocessing\n",
71 | "\n",
72 | "vad_models = dict()\n",
73 | "\n",
74 | "def init_model(model):\n",
75 | " pid = multiprocessing.current_process().pid\n",
76 | " model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
77 | " model='silero_vad',\n",
78 | " force_reload=False,\n",
79 | " onnx=False)\n",
80 | " vad_models[pid] = model\n",
81 | "\n",
82 | "def vad_process(audio_file: str):\n",
83 | " \n",
84 | " pid = multiprocessing.current_process().pid\n",
85 | " \n",
86 | " with torch.no_grad():\n",
87 | " wav = read_audio(audio_file, sampling_rate=SAMPLING_RATE)\n",
88 | " return get_speech_timestamps(\n",
89 | " wav,\n",
90 | " vad_models[pid],\n",
91 | " 0.46, # speech prob threshold\n",
92 | " 16000, # sample rate\n",
93 | " 300, # min speech duration in ms\n",
94 | " 20, # max speech duration in seconds\n",
95 | " 600, # min silence duration\n",
96 | " 512, # window size\n",
97 |     "            200,  # speech pad ms\n",
98 | " )"
99 | ]
100 | },
101 | {
102 | "cell_type": "markdown",
103 | "metadata": {},
104 | "source": [
105 | "## Parallelization"
106 | ]
107 | },
108 | {
109 | "cell_type": "code",
110 | "execution_count": null,
111 | "metadata": {},
112 | "outputs": [],
113 | "source": [
114 | "from concurrent.futures import ProcessPoolExecutor, as_completed\n",
115 | "\n",
116 | "futures = []\n",
117 | "\n",
118 | "with ProcessPoolExecutor(max_workers=NUM_PROCESS, initializer=init_model, initargs=(model,)) as ex:\n",
119 | " for i in range(NUM_COPIES):\n",
120 |     "        futures.append(ex.submit(vad_process, f\"en_example{i}.wav\"))\n",
121 | "\n",
122 | "for finished in as_completed(futures):\n",
123 | " pprint(finished.result())"
124 | ]
125 | }
126 | ],
127 | "metadata": {
128 | "kernelspec": {
129 | "display_name": "Python 3 (ipykernel)",
130 | "language": "python",
131 | "name": "python3"
132 | },
133 | "language_info": {
134 | "codemirror_mode": {
135 | "name": "ipython",
136 | "version": 3
137 | },
138 | "file_extension": ".py",
139 | "mimetype": "text/x-python",
140 | "name": "python",
141 | "nbconvert_exporter": "python",
142 | "pygments_lexer": "ipython3",
143 | "version": "3.10.14"
144 | },
145 | "toc": {
146 | "base_numbering": 1,
147 | "nav_menu": {},
148 | "number_sections": true,
149 | "sideBar": true,
150 | "skip_h1_title": false,
151 | "title_cell": "Table of Contents",
152 | "title_sidebar": "Contents",
153 | "toc_cell": false,
154 | "toc_position": {},
155 | "toc_section_display": true,
156 | "toc_window_display": false
157 | }
158 | },
159 | "nbformat": 4,
160 | "nbformat_minor": 2
161 | }
162 |
--------------------------------------------------------------------------------
/examples/pyaudio-streaming/README.md:
--------------------------------------------------------------------------------
1 | # Pyaudio Streaming Example
2 |
3 | This example notebook shows how microphone audio fetched by pyaudio can be processed with Silero-VAD.
4 | 
5 | It has been designed as a low-level example of binary real-time streaming that uses only the model's predictions: it processes the binary data and plots the speech probabilities at the end for visualization.
6 | 
7 | Currently, the notebook consists of two examples:
8 | - One that records audio of a predefined length from the microphone, processes it with Silero-VAD, and plots the result afterwards.
9 | - The other plots the speech probabilities in real time (using jupyterplot) and records audio until you press Enter.
10 |
11 | This example does not work in google colab! For local usage only.
12 |
13 | ## Example Video for the Real-Time Visualization
14 |
15 |
16 | https://user-images.githubusercontent.com/8079748/117580455-4622dd00-b0f8-11eb-858d-e6368ed4eada.mp4
17 |
18 |
19 |
20 |
21 |
22 |
23 |
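For orientation, the core per-chunk loop that both notebook examples build on looks roughly like this (a minimal sketch assuming `model`, `int2float`, `np`, `torch` and the pyaudio `stream` are already set up as in the notebook):
```
num_samples = 512                                    # 512 samples = 32 ms at 16 kHz
audio_chunk = stream.read(num_samples)               # raw int16 bytes from the microphone
audio_int16 = np.frombuffer(audio_chunk, np.int16)   # bytes -> int16 array
audio_float32 = int2float(audio_int16)               # int16 -> float32
confidence = model(torch.from_numpy(audio_float32), 16000).item()   # speech probability for this chunk
```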
--------------------------------------------------------------------------------
/examples/pyaudio-streaming/pyaudio-streaming-examples.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "76aa55ba",
6 | "metadata": {},
7 | "source": [
8 | "# Pyaudio Microphone Streaming Examples\n",
9 | "\n",
10 |     "A simple notebook that uses pyaudio to get the microphone audio and then feeds this audio to Silero VAD.\n",
11 |     "\n",
12 |     "I created it as an example of how binary data from a stream can be fed into Silero VAD.\n",
13 | "\n",
14 | "\n",
15 | "Has been tested on Ubuntu 21.04 (x86). After you installed the dependencies below, no additional setup is required.\n",
16 | "\n",
17 | "This notebook does not work in google colab! For local usage only."
18 | ]
19 | },
20 | {
21 | "cell_type": "markdown",
22 | "id": "4a4e15c2",
23 | "metadata": {},
24 | "source": [
25 | "## Dependencies\n",
26 | "The cell below lists all used dependencies and the used versions. Uncomment to install them from within the notebook."
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": 1,
32 | "id": "24205cce",
33 | "metadata": {
34 | "ExecuteTime": {
35 | "end_time": "2024-10-09T08:47:34.056898Z",
36 | "start_time": "2024-10-09T08:47:34.053418Z"
37 | }
38 | },
39 | "outputs": [],
40 | "source": [
41 | "#!pip install numpy>=1.24.0\n",
42 | "#!pip install torch>=1.12.0\n",
43 | "#!pip install matplotlib>=3.6.0\n",
44 | "#!pip install torchaudio>=0.12.0\n",
45 | "#!pip install soundfile==0.12.1\n",
46 | "#!apt install python3-pyaudio (linux) or pip install pyaudio (windows)"
47 | ]
48 | },
49 | {
50 | "cell_type": "markdown",
51 | "id": "cd22818f",
52 | "metadata": {},
53 | "source": [
54 | "## Imports"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": 2,
60 | "id": "994d7f3a",
61 | "metadata": {
62 | "ExecuteTime": {
63 | "end_time": "2024-10-09T08:47:39.005032Z",
64 | "start_time": "2024-10-09T08:47:36.489952Z"
65 | }
66 | },
67 |    "outputs": [],
80 | "source": [
81 | "import io\n",
82 | "import numpy as np\n",
83 | "import torch\n",
84 | "torch.set_num_threads(1)\n",
85 | "import torchaudio\n",
86 | "import matplotlib\n",
87 | "import matplotlib.pylab as plt\n",
88 | "import pyaudio"
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": null,
94 | "id": "ac5c52f7",
95 | "metadata": {},
96 | "outputs": [],
97 | "source": [
98 | "model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
99 | " model='silero_vad',\n",
100 | " force_reload=True)"
101 | ]
102 | },
103 | {
104 | "cell_type": "code",
105 | "execution_count": null,
106 | "id": "ad5919dc",
107 | "metadata": {},
108 | "outputs": [],
109 | "source": [
110 | "(get_speech_timestamps,\n",
111 | " save_audio,\n",
112 | " read_audio,\n",
113 | " VADIterator,\n",
114 | " collect_chunks) = utils"
115 | ]
116 | },
117 | {
118 | "cell_type": "markdown",
119 | "id": "784d1ab6",
120 | "metadata": {},
121 | "source": [
122 | "### Helper Methods"
123 | ]
124 | },
125 | {
126 | "cell_type": "code",
127 | "execution_count": null,
128 | "id": "af4bca64",
129 | "metadata": {},
130 | "outputs": [],
131 | "source": [
132 | "# Taken from utils_vad.py\n",
133 | "def validate(model,\n",
134 | " inputs: torch.Tensor):\n",
135 | " with torch.no_grad():\n",
136 | " outs = model(inputs)\n",
137 | " return outs\n",
138 | "\n",
139 | "# Provided by Alexander Veysov\n",
140 | "def int2float(sound):\n",
141 | " abs_max = np.abs(sound).max()\n",
142 | " sound = sound.astype('float32')\n",
143 | " if abs_max > 0:\n",
144 | " sound *= 1/32768\n",
145 | " sound = sound.squeeze() # depends on the use case\n",
146 | " return sound"
147 | ]
148 | },
149 | {
150 | "cell_type": "markdown",
151 | "id": "ca13e514",
152 | "metadata": {},
153 | "source": [
154 | "## Pyaudio Set-up"
155 | ]
156 | },
157 | {
158 | "cell_type": "code",
159 | "execution_count": null,
160 | "id": "75f99022",
161 | "metadata": {},
162 | "outputs": [],
163 | "source": [
164 | "FORMAT = pyaudio.paInt16\n",
165 | "CHANNELS = 1\n",
166 | "SAMPLE_RATE = 16000\n",
167 | "CHUNK = int(SAMPLE_RATE / 10)\n",
168 | "\n",
169 | "audio = pyaudio.PyAudio()"
170 | ]
171 | },
172 | {
173 | "cell_type": "markdown",
174 | "id": "4da7d2ef",
175 | "metadata": {},
176 | "source": [
177 | "## Simple Example\n",
178 |     "The following example reads the audio in 512-sample (32 ms at 16 kHz) chunks from the microphone, converts them to a PyTorch tensor, and gets the model's probability/confidence that each chunk is voiced."
179 | ]
180 | },
181 | {
182 | "cell_type": "code",
183 | "execution_count": null,
184 | "id": "6fe77661",
185 | "metadata": {},
186 | "outputs": [],
187 | "source": [
188 | "num_samples = 512"
189 | ]
190 | },
191 | {
192 | "cell_type": "code",
193 | "execution_count": null,
194 | "id": "23f4da3e",
195 | "metadata": {},
196 | "outputs": [],
197 | "source": [
198 | "stream = audio.open(format=FORMAT,\n",
199 | " channels=CHANNELS,\n",
200 | " rate=SAMPLE_RATE,\n",
201 | " input=True,\n",
202 | " frames_per_buffer=CHUNK)\n",
203 | "data = []\n",
204 | "voiced_confidences = []\n",
205 | "\n",
206 | "frames_to_record = 50\n",
207 | "\n",
208 | "print(\"Started Recording\")\n",
209 | "for i in range(0, frames_to_record):\n",
210 | " \n",
211 | " audio_chunk = stream.read(num_samples)\n",
212 | " \n",
213 | " # in case you want to save the audio later\n",
214 | " data.append(audio_chunk)\n",
215 | " \n",
216 |     "    audio_int16 = np.frombuffer(audio_chunk, np.int16)\n",
217 | "\n",
218 | " audio_float32 = int2float(audio_int16)\n",
219 | " \n",
220 | " # get the confidences and add them to the list to plot them later\n",
221 | " new_confidence = model(torch.from_numpy(audio_float32), 16000).item()\n",
222 | " voiced_confidences.append(new_confidence)\n",
223 | " \n",
224 | "print(\"Stopped the recording\")\n",
225 | "\n",
226 | "# plot the confidences for the speech\n",
227 | "plt.figure(figsize=(20,6))\n",
228 | "plt.plot(voiced_confidences)\n",
229 | "plt.show()"
230 | ]
231 | },
232 | {
233 | "cell_type": "markdown",
234 | "id": "fd243e8f",
235 | "metadata": {},
236 | "source": [
237 | "## Real Time Visualization\n",
238 | "\n",
239 |     "As an enhancement, to plot the speech probabilities in real time I added the implementation below.\n",
240 |     "In contrast to the simple one, it records the audio until you stop the recording by pressing Enter.\n",
241 |     "While looking into good ways to update matplotlib plots in real time, I found a simple library that does the job: https://github.com/lvwerra/jupyterplot. It has some limitations, but works really well for this use case.\n"
242 | ]
243 | },
244 | {
245 | "cell_type": "code",
246 | "execution_count": null,
247 | "id": "d36980c2",
248 | "metadata": {},
249 | "outputs": [],
250 | "source": [
251 | "#!pip install jupyterplot==0.0.3"
252 | ]
253 | },
254 | {
255 | "cell_type": "code",
256 | "execution_count": null,
257 | "id": "5607b616",
258 | "metadata": {},
259 | "outputs": [],
260 | "source": [
261 | "from jupyterplot import ProgressPlot\n",
262 | "import threading\n",
263 | "\n",
264 | "continue_recording = True\n",
265 | "\n",
266 | "def stop():\n",
267 | " input(\"Press Enter to stop the recording:\")\n",
268 | " global continue_recording\n",
269 | " continue_recording = False\n",
270 | "\n",
271 | "def start_recording():\n",
272 | " \n",
273 | " stream = audio.open(format=FORMAT,\n",
274 | " channels=CHANNELS,\n",
275 | " rate=SAMPLE_RATE,\n",
276 | " input=True,\n",
277 | " frames_per_buffer=CHUNK)\n",
278 | "\n",
279 | " data = []\n",
280 | " voiced_confidences = []\n",
281 | " \n",
282 | " global continue_recording\n",
283 | " continue_recording = True\n",
284 | " \n",
285 | " pp = ProgressPlot(plot_names=[\"Silero VAD\"],line_names=[\"speech probabilities\"], x_label=\"audio chunks\")\n",
286 | " \n",
287 | " stop_listener = threading.Thread(target=stop)\n",
288 | " stop_listener.start()\n",
289 | "\n",
290 | " while continue_recording:\n",
291 | " \n",
292 | " audio_chunk = stream.read(num_samples)\n",
293 | " \n",
294 | " # in case you want to save the audio later\n",
295 | " data.append(audio_chunk)\n",
296 | " \n",
297 |     "        audio_int16 = np.frombuffer(audio_chunk, np.int16)\n",
298 | "\n",
299 | " audio_float32 = int2float(audio_int16)\n",
300 | " \n",
301 | " # get the confidences and add them to the list to plot them later\n",
302 | " new_confidence = model(torch.from_numpy(audio_float32), 16000).item()\n",
303 | " voiced_confidences.append(new_confidence)\n",
304 | " \n",
305 | " pp.update(new_confidence)\n",
306 | "\n",
307 | "\n",
308 | " pp.finalize()"
309 | ]
310 | },
311 | {
312 | "cell_type": "code",
313 | "execution_count": null,
314 | "id": "dc4f0108",
315 | "metadata": {},
316 | "outputs": [],
317 | "source": [
318 | "start_recording()"
319 | ]
320 | }
321 | ],
322 | "metadata": {
323 | "kernelspec": {
324 | "display_name": "Python 3 (ipykernel)",
325 | "language": "python",
326 | "name": "python3"
327 | },
328 | "language_info": {
329 | "codemirror_mode": {
330 | "name": "ipython",
331 | "version": 3
332 | },
333 | "file_extension": ".py",
334 | "mimetype": "text/x-python",
335 | "name": "python",
336 | "nbconvert_exporter": "python",
337 | "pygments_lexer": "ipython3",
338 | "version": "3.10.14"
339 | },
340 | "toc": {
341 | "base_numbering": 1,
342 | "nav_menu": {},
343 | "number_sections": true,
344 | "sideBar": true,
345 | "skip_h1_title": false,
346 | "title_cell": "Table of Contents",
347 | "title_sidebar": "Contents",
348 | "toc_cell": false,
349 | "toc_position": {},
350 | "toc_section_display": true,
351 | "toc_window_display": false
352 | }
353 | },
354 | "nbformat": 4,
355 | "nbformat_minor": 5
356 | }
357 |
--------------------------------------------------------------------------------
/examples/rust-example/.gitignore:
--------------------------------------------------------------------------------
1 | target/
2 | recorder.wav
--------------------------------------------------------------------------------
/examples/rust-example/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "rust-example"
3 | version = "0.1.0"
4 | edition = "2021"
5 |
6 | [dependencies]
7 | ort = { version = "2.0.0-rc.2", features = ["load-dynamic", "ndarray"] }
8 | ndarray = "0.15"
9 | hound = "3"
10 |
--------------------------------------------------------------------------------
/examples/rust-example/README.md:
--------------------------------------------------------------------------------
1 | # Stream example in Rust
2 | Made after [C++ stream example](https://github.com/snakers4/silero-vad/tree/master/examples/cpp)
3 |
4 | ## Dependencies
5 | - To build Rust crate `ort` you need `cc` installed.
6 |
7 | ## Usage
8 | Just
9 | ```
10 | cargo run
11 | ```
12 | If you run the example outside of this repo, adjust the environment variable:
13 | ```
14 | SILERO_MODEL_PATH=/path/to/silero_vad.onnx cargo run
15 | ```
16 | If you need to test against another wav file instead of `recorder.wav`, specify it as the first argument:
17 | ```
18 | cargo run -- /path/to/audio/file.wav
19 | ```
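
Note: `Cargo.toml` enables the `load-dynamic` feature of the `ort` crate, so an ONNX Runtime shared library must be discoverable at runtime. If it is not found automatically, pointing `ORT_DYLIB_PATH` at it is one option (this is an assumption based on the `ort` crate's `load-dynamic` feature, not something this example configures):
```
ORT_DYLIB_PATH=/path/to/libonnxruntime.so cargo run
```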
--------------------------------------------------------------------------------
/examples/rust-example/src/main.rs:
--------------------------------------------------------------------------------
1 | mod silero;
2 | mod utils;
3 | mod vad_iter;
4 |
5 | fn main() {
6 | let model_path = std::env::var("SILERO_MODEL_PATH")
7 |         .unwrap_or_else(|_| String::from("../../src/silero_vad/data/silero_vad.onnx"));
8 | let audio_path = std::env::args()
9 | .nth(1)
10 | .unwrap_or_else(|| String::from("recorder.wav"));
11 | let mut wav_reader = hound::WavReader::open(audio_path).unwrap();
12 | let sample_rate = match wav_reader.spec().sample_rate {
13 | 8000 => utils::SampleRate::EightkHz,
14 | 16000 => utils::SampleRate::SixteenkHz,
15 | _ => panic!("Unsupported sample rate. Expect 8 kHz or 16 kHz."),
16 | };
17 | if wav_reader.spec().sample_format != hound::SampleFormat::Int {
18 | panic!("Unsupported sample format. Expect Int.");
19 | }
20 | let content = wav_reader
21 | .samples()
22 | .filter_map(|x| x.ok())
23 |         .collect::<Vec<_>>();
24 | assert!(!content.is_empty());
25 | let silero = silero::Silero::new(sample_rate, model_path).unwrap();
26 | let vad_params = utils::VadParams {
27 | sample_rate: sample_rate.into(),
28 | ..Default::default()
29 | };
30 | let mut vad_iterator = vad_iter::VadIter::new(silero, vad_params);
31 | vad_iterator.process(&content).unwrap();
32 | for timestamp in vad_iterator.speeches() {
33 | println!("{}", timestamp);
34 | }
35 | println!("Finished.");
36 | }
37 |
--------------------------------------------------------------------------------
/examples/rust-example/src/silero.rs:
--------------------------------------------------------------------------------
1 | use crate::utils;
2 | use ndarray::{s, Array, Array2, ArrayBase, ArrayD, Dim, IxDynImpl, OwnedRepr};
3 | use std::path::Path;
4 |
5 | #[derive(Debug)]
6 | pub struct Silero {
7 | session: ort::Session,
8 |     sample_rate: ArrayBase<OwnedRepr<i64>, Dim<[usize; 1]>>,
9 |     state: ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>,
10 | }
11 |
12 | impl Silero {
13 | pub fn new(
14 | sample_rate: utils::SampleRate,
15 |         model_path: impl AsRef<Path>,
16 |     ) -> Result<Self, ort::Error> {
17 | let session = ort::Session::builder()?.commit_from_file(model_path)?;
18 |         let state = ArrayD::<f32>::zeros([2, 1, 128].as_slice());
19 | let sample_rate = Array::from_shape_vec([1], vec![sample_rate.into()]).unwrap();
20 | Ok(Self {
21 | session,
22 | sample_rate,
23 | state,
24 | })
25 | }
26 |
27 | pub fn reset(&mut self) {
28 |         self.state = ArrayD::<f32>::zeros([2, 1, 128].as_slice());
29 | }
30 |
31 |     pub fn calc_level(&mut self, audio_frame: &[i16]) -> Result<f32, ort::Error> {
32 | let data = audio_frame
33 | .iter()
34 | .map(|x| (*x as f32) / (i16::MAX as f32))
35 |             .collect::<Vec<_>>();
36 | let mut frame = Array2::::from_shape_vec([1, data.len()], data).unwrap();
37 | frame = frame.slice(s![.., ..480]).to_owned();
38 | let inps = ort::inputs![
39 | frame,
40 | std::mem::take(&mut self.state),
41 | self.sample_rate.clone(),
42 | ]?;
43 | let res = self
44 | .session
45 | .run(ort::SessionInputs::ValueSlice::<3>(&inps))?;
46 | self.state = res["stateN"].try_extract_tensor().unwrap().to_owned();
47 | Ok(*res["output"]
48 |             .try_extract_raw_tensor::<f32>()
49 | .unwrap()
50 | .1
51 | .first()
52 | .unwrap())
53 | }
54 | }
55 |
--------------------------------------------------------------------------------
/examples/rust-example/src/utils.rs:
--------------------------------------------------------------------------------
1 | #[derive(Debug, Clone, Copy)]
2 | pub enum SampleRate {
3 | EightkHz,
4 | SixteenkHz,
5 | }
6 |
7 | impl From<SampleRate> for i64 {
8 | fn from(value: SampleRate) -> Self {
9 | match value {
10 | SampleRate::EightkHz => 8000,
11 | SampleRate::SixteenkHz => 16000,
12 | }
13 | }
14 | }
15 |
16 | impl From<SampleRate> for usize {
17 | fn from(value: SampleRate) -> Self {
18 | match value {
19 | SampleRate::EightkHz => 8000,
20 | SampleRate::SixteenkHz => 16000,
21 | }
22 | }
23 | }
24 |
25 | #[derive(Debug)]
26 | pub struct VadParams {
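    // Note: frame_size, speech_pad_ms, min_speech_duration_ms and min_silence_duration_ms are in
    // milliseconds; vad_iter.rs converts them to sample counts via sr_per_ms = sample_rate / 1000.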
27 | pub frame_size: usize,
28 | pub threshold: f32,
29 | pub min_silence_duration_ms: usize,
30 | pub speech_pad_ms: usize,
31 | pub min_speech_duration_ms: usize,
32 | pub max_speech_duration_s: f32,
33 | pub sample_rate: usize,
34 | }
35 |
36 | impl Default for VadParams {
37 | fn default() -> Self {
38 | Self {
39 | frame_size: 64,
40 | threshold: 0.5,
41 | min_silence_duration_ms: 0,
42 | speech_pad_ms: 64,
43 | min_speech_duration_ms: 64,
44 | max_speech_duration_s: f32::INFINITY,
45 | sample_rate: 16000,
46 | }
47 | }
48 | }
49 |
50 | #[derive(Debug, Default)]
51 | pub struct TimeStamp {
52 | pub start: i64,
53 | pub end: i64,
54 | }
55 |
56 | impl std::fmt::Display for TimeStamp {
57 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58 | write!(f, "[start:{:08}, end:{:08}]", self.start, self.end)
59 | }
60 | }
61 |
--------------------------------------------------------------------------------
/examples/rust-example/src/vad_iter.rs:
--------------------------------------------------------------------------------
1 | use crate::{silero, utils};
2 |
3 | const DEBUG_SPEECH_PROB: bool = true;
4 | #[derive(Debug)]
5 | pub struct VadIter {
6 | silero: silero::Silero,
7 | params: Params,
8 | state: State,
9 | }
10 |
11 | impl VadIter {
12 | pub fn new(silero: silero::Silero, params: utils::VadParams) -> Self {
13 | Self {
14 | silero,
15 | params: Params::from(params),
16 | state: State::new(),
17 | }
18 | }
19 |
20 | pub fn process(&mut self, samples: &[i16]) -> Result<(), ort::Error> {
21 | self.reset_states();
22 | for audio_frame in samples.chunks_exact(self.params.frame_size_samples) {
23 | let speech_prob: f32 = self.silero.calc_level(audio_frame)?;
24 | self.state.update(&self.params, speech_prob);
25 | }
26 | self.state.check_for_last_speech(samples.len());
27 | Ok(())
28 | }
29 |
30 | pub fn speeches(&self) -> &[utils::TimeStamp] {
31 | &self.state.speeches
32 | }
33 | }
34 |
35 | impl VadIter {
36 | fn reset_states(&mut self) {
37 | self.silero.reset();
38 | self.state = State::new()
39 | }
40 | }
41 |
42 | #[allow(unused)]
43 | #[derive(Debug)]
44 | struct Params {
45 | frame_size: usize,
46 | threshold: f32,
47 | min_silence_duration_ms: usize,
48 | speech_pad_ms: usize,
49 | min_speech_duration_ms: usize,
50 | max_speech_duration_s: f32,
51 | sample_rate: usize,
52 | sr_per_ms: usize,
53 | frame_size_samples: usize,
54 | min_speech_samples: usize,
55 | speech_pad_samples: usize,
56 | max_speech_samples: f32,
57 | min_silence_samples: usize,
58 | min_silence_samples_at_max_speech: usize,
59 | }
60 |
61 | impl From<utils::VadParams> for Params {
62 | fn from(value: utils::VadParams) -> Self {
63 | let frame_size = value.frame_size;
64 | let threshold = value.threshold;
65 | let min_silence_duration_ms = value.min_silence_duration_ms;
66 | let speech_pad_ms = value.speech_pad_ms;
67 | let min_speech_duration_ms = value.min_speech_duration_ms;
68 | let max_speech_duration_s = value.max_speech_duration_s;
69 | let sample_rate = value.sample_rate;
70 | let sr_per_ms = sample_rate / 1000;
71 | let frame_size_samples = frame_size * sr_per_ms;
72 | let min_speech_samples = sr_per_ms * min_speech_duration_ms;
73 | let speech_pad_samples = sr_per_ms * speech_pad_ms;
74 | let max_speech_samples = sample_rate as f32 * max_speech_duration_s
75 | - frame_size_samples as f32
76 | - 2.0 * speech_pad_samples as f32;
77 | let min_silence_samples = sr_per_ms * min_silence_duration_ms;
78 | let min_silence_samples_at_max_speech = sr_per_ms * 98;
79 | Self {
80 | frame_size,
81 | threshold,
82 | min_silence_duration_ms,
83 | speech_pad_ms,
84 | min_speech_duration_ms,
85 | max_speech_duration_s,
86 | sample_rate,
87 | sr_per_ms,
88 | frame_size_samples,
89 | min_speech_samples,
90 | speech_pad_samples,
91 | max_speech_samples,
92 | min_silence_samples,
93 | min_silence_samples_at_max_speech,
94 | }
95 | }
96 | }
97 |
98 | #[derive(Debug, Default)]
99 | struct State {
100 | current_sample: usize,
101 | temp_end: usize,
102 | next_start: usize,
103 | prev_end: usize,
104 | triggered: bool,
105 | current_speech: utils::TimeStamp,
106 |     speeches: Vec<utils::TimeStamp>,
107 | }
108 |
109 | impl State {
110 | fn new() -> Self {
111 | Default::default()
112 | }
113 |
114 | fn update(&mut self, params: &Params, speech_prob: f32) {
115 | self.current_sample += params.frame_size_samples;
116 | if speech_prob > params.threshold {
117 | if self.temp_end != 0 {
118 | self.temp_end = 0;
119 | if self.next_start < self.prev_end {
120 | self.next_start = self
121 | .current_sample
122 | .saturating_sub(params.frame_size_samples)
123 | }
124 | }
125 | if !self.triggered {
126 | self.debug(speech_prob, params, "start");
127 | self.triggered = true;
128 | self.current_speech.start =
129 | self.current_sample as i64 - params.frame_size_samples as i64;
130 | }
131 | return;
132 | }
133 | if self.triggered
134 | && (self.current_sample as i64 - self.current_speech.start) as f32
135 | > params.max_speech_samples
136 | {
137 | if self.prev_end > 0 {
138 | self.current_speech.end = self.prev_end as _;
139 | self.take_speech();
140 | if self.next_start < self.prev_end {
141 | self.triggered = false
142 | } else {
143 | self.current_speech.start = self.next_start as _;
144 | }
145 | self.prev_end = 0;
146 | self.next_start = 0;
147 | self.temp_end = 0;
148 | } else {
149 | self.current_speech.end = self.current_sample as _;
150 | self.take_speech();
151 | self.prev_end = 0;
152 | self.next_start = 0;
153 | self.temp_end = 0;
154 | self.triggered = false;
155 | }
156 | return;
157 | }
158 | if speech_prob >= (params.threshold - 0.15) && (speech_prob < params.threshold) {
159 | if self.triggered {
160 | self.debug(speech_prob, params, "speaking")
161 | } else {
162 | self.debug(speech_prob, params, "silence")
163 | }
164 | }
165 | if self.triggered && speech_prob < (params.threshold - 0.15) {
166 | self.debug(speech_prob, params, "end");
167 | if self.temp_end == 0 {
168 | self.temp_end = self.current_sample;
169 | }
170 | if self.current_sample.saturating_sub(self.temp_end)
171 | > params.min_silence_samples_at_max_speech
172 | {
173 | self.prev_end = self.temp_end;
174 | }
175 | if self.current_sample.saturating_sub(self.temp_end) >= params.min_silence_samples {
176 | self.current_speech.end = self.temp_end as _;
177 | if self.current_speech.end - self.current_speech.start
178 | > params.min_speech_samples as _
179 | {
180 | self.take_speech();
181 | self.prev_end = 0;
182 | self.next_start = 0;
183 | self.temp_end = 0;
184 | self.triggered = false;
185 | }
186 | }
187 | }
188 | }
189 |
190 | fn take_speech(&mut self) {
191 | self.speeches.push(std::mem::take(&mut self.current_speech)); // current speech becomes TimeStamp::default() due to take()
192 | }
193 |
194 | fn check_for_last_speech(&mut self, last_sample: usize) {
195 | if self.current_speech.start > 0 {
196 | self.current_speech.end = last_sample as _;
197 | self.take_speech();
198 | self.prev_end = 0;
199 | self.next_start = 0;
200 | self.temp_end = 0;
201 | self.triggered = false;
202 | }
203 | }
204 |
205 | fn debug(&self, speech_prob: f32, params: &Params, title: &str) {
206 | if DEBUG_SPEECH_PROB {
207 | let speech = self.current_sample as f32
208 | - params.frame_size_samples as f32
209 | - if title == "end" {
210 | params.speech_pad_samples
211 | } else {
212 | 0
213 | } as f32; // minus window_size_samples to get precise start time point.
214 | println!(
215 | "[{:10}: {:.3} s ({:.3}) {:8}]",
216 | title,
217 | speech / params.sample_rate as f32,
218 | speech_prob,
219 | self.current_sample - params.frame_size_samples,
220 | );
221 | }
222 | }
223 | }
224 |
--------------------------------------------------------------------------------
/files/silero_logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snakers4/silero-vad/0dd45f0bcd7271463c234f3bae5ad25181f9df8b/files/silero_logo.jpg
--------------------------------------------------------------------------------
/hubconf.py:
--------------------------------------------------------------------------------
1 | dependencies = ['torch', 'torchaudio']
2 | import torch
3 | import os
4 | import sys
5 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
6 | from silero_vad.utils_vad import (init_jit_model,
7 | get_speech_timestamps,
8 | save_audio,
9 | read_audio,
10 | VADIterator,
11 | collect_chunks,
12 | OnnxWrapper)
13 |
14 |
15 | def versiontuple(v):
16 | splitted = v.split('+')[0].split(".")
17 | version_list = []
18 | for i in splitted:
19 | try:
20 | version_list.append(int(i))
21 | except:
22 | version_list.append(0)
23 | return tuple(version_list)
24 |
25 |
26 | def silero_vad(onnx=False, force_onnx_cpu=False, opset_version=16):
27 | """Silero Voice Activity Detector
28 | Returns a model with a set of utils
29 | Please see https://github.com/snakers4/silero-vad for usage examples
30 | """
31 | available_ops = [15, 16]
32 | if onnx and opset_version not in available_ops:
33 | raise Exception(f'Available ONNX opset_version: {available_ops}')
34 |
35 | if not onnx:
36 | installed_version = torch.__version__
37 | supported_version = '1.12.0'
38 | if versiontuple(installed_version) < versiontuple(supported_version):
39 | raise Exception(f'Please install torch {supported_version} or greater ({installed_version} installed)')
40 |
41 | model_dir = os.path.join(os.path.dirname(__file__), 'src', 'silero_vad', 'data')
42 | if onnx:
43 | if opset_version == 16:
44 | model_name = 'silero_vad.onnx'
45 | else:
46 | model_name = f'silero_vad_16k_op{opset_version}.onnx'
47 | model = OnnxWrapper(os.path.join(model_dir, model_name), force_onnx_cpu)
48 | else:
49 | model = init_jit_model(os.path.join(model_dir, 'silero_vad.jit'))
50 | utils = (get_speech_timestamps,
51 | save_audio,
52 | read_audio,
53 | VADIterator,
54 | collect_chunks)
55 |
56 | return model, utils
57 |
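For reference, this is the entry point that `torch.hub.load` invokes; a minimal usage sketch with the keyword arguments defined above (`torch.hub.load` forwards extra kwargs to the entry point):
```
import torch

model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                              model='silero_vad',   # name of the entry point above
                              onnx=True,            # False loads the JIT model instead
                              opset_version=16)     # 15 or 16, per available_ops
(get_speech_timestamps, save_audio, read_audio,
 VADIterator, collect_chunks) = utils
```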
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling"]
3 | build-backend = "hatchling.build"
4 | [project]
5 | name = "silero-vad"
6 | version = "5.1.2"
7 | authors = [
8 | {name="Silero Team", email="hello@silero.ai"},
9 | ]
10 | description = "Voice Activity Detector (VAD) by Silero"
11 | readme = "README.md"
12 | requires-python = ">=3.8"
13 | classifiers = [
14 | "Development Status :: 5 - Production/Stable",
15 | "License :: OSI Approved :: MIT License",
16 | "Operating System :: OS Independent",
17 | "Intended Audience :: Science/Research",
18 | "Intended Audience :: Developers",
19 | "Programming Language :: Python :: 3.8",
20 | "Programming Language :: Python :: 3.9",
21 | "Programming Language :: Python :: 3.10",
22 | "Programming Language :: Python :: 3.11",
23 | "Programming Language :: Python :: 3.12",
24 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
25 | "Topic :: Scientific/Engineering",
26 | ]
27 | dependencies = [
28 | "torch>=1.12.0",
29 | "torchaudio>=0.12.0",
30 | "onnxruntime>=1.16.1",
31 | ]
32 |
33 | [project.urls]
34 | Homepage = "https://github.com/snakers4/silero-vad"
35 | Issues = "https://github.com/snakers4/silero-vad/issues"
36 |
--------------------------------------------------------------------------------
/silero-vad.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "heading_collapsed": true,
7 | "id": "62A6F_072Fwq"
8 | },
9 | "source": [
10 | "## Install Dependencies"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "metadata": {
17 | "hidden": true,
18 | "id": "5w5AkskZ2Fwr"
19 | },
20 | "outputs": [],
21 | "source": [
22 | "#@title Install and Import Dependencies\n",
23 | "\n",
24 | "# this assumes that you have a relevant version of PyTorch installed\n",
25 | "!pip install -q torchaudio\n",
26 | "\n",
27 | "SAMPLING_RATE = 16000\n",
28 | "\n",
29 | "import torch\n",
30 | "torch.set_num_threads(1)\n",
31 | "\n",
32 | "from IPython.display import Audio\n",
33 | "from pprint import pprint\n",
34 | "# download example\n",
35 | "torch.hub.download_url_to_file('https://models.silero.ai/vad_models/en.wav', 'en_example.wav')"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": null,
41 | "metadata": {
42 | "id": "pSifus5IilRp"
43 | },
44 | "outputs": [],
45 | "source": [
46 | "USE_PIP = True # download model using pip package or torch.hub\n",
47 | "USE_ONNX = False # change this to True if you want to test onnx model\n",
48 | "if USE_ONNX:\n",
49 | " !pip install -q onnxruntime\n",
50 | "if USE_PIP:\n",
51 | " !pip install -q silero-vad\n",
52 | " from silero_vad import (load_silero_vad,\n",
53 | " read_audio,\n",
54 | " get_speech_timestamps,\n",
55 | " save_audio,\n",
56 | " VADIterator,\n",
57 | " collect_chunks)\n",
58 | " model = load_silero_vad(onnx=USE_ONNX)\n",
59 | "else:\n",
60 | " model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
61 | " model='silero_vad',\n",
62 | " force_reload=True,\n",
63 | " onnx=USE_ONNX)\n",
64 | "\n",
65 | " (get_speech_timestamps,\n",
66 | " save_audio,\n",
67 | " read_audio,\n",
68 | " VADIterator,\n",
69 | " collect_chunks) = utils"
70 | ]
71 | },
72 | {
73 | "cell_type": "markdown",
74 | "metadata": {
75 | "id": "fXbbaUO3jsrw"
76 | },
77 | "source": [
78 |     "## Speech timestamps from full audio"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": null,
84 | "metadata": {
85 | "id": "aI_eydBPjsrx"
86 | },
87 | "outputs": [],
88 | "source": [
89 | "wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)\n",
90 | "# get speech timestamps from full audio file\n",
91 | "speech_timestamps = get_speech_timestamps(wav, model, sampling_rate=SAMPLING_RATE)\n",
92 | "pprint(speech_timestamps)"
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": null,
98 | "metadata": {
99 | "id": "OuEobLchjsry"
100 | },
101 | "outputs": [],
102 | "source": [
103 | "# merge all speech chunks to one audio\n",
104 | "save_audio('only_speech.wav',\n",
105 | " collect_chunks(speech_timestamps, wav), sampling_rate=SAMPLING_RATE)\n",
106 | "Audio('only_speech.wav')"
107 | ]
108 | },
109 | {
110 | "cell_type": "markdown",
111 | "metadata": {
112 | "id": "zeO1xCqxUC6w"
113 | },
114 | "source": [
115 | "## Entire audio inference"
116 | ]
117 | },
118 | {
119 | "cell_type": "code",
120 | "execution_count": null,
121 | "metadata": {
122 | "id": "LjZBcsaTT7Mk"
123 | },
124 | "outputs": [],
125 | "source": [
126 | "wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)\n",
127 |     "# audio is split into 512-sample chunks (32 ms at 16 kHz), i.e. 31.25 chunks per second\n",
128 |     "# so output length equals ceil(input_length * 31.25 / SAMPLING_RATE)\n",
129 | "predicts = model.audio_forward(wav, sr=SAMPLING_RATE)"
130 | ]
131 | },
132 | {
133 | "cell_type": "markdown",
134 | "metadata": {
135 | "id": "iDKQbVr8jsry"
136 | },
137 | "source": [
138 | "## Stream imitation example"
139 | ]
140 | },
141 | {
142 | "cell_type": "code",
143 | "execution_count": null,
144 | "metadata": {
145 | "id": "q-lql_2Wjsry"
146 | },
147 | "outputs": [],
148 | "source": [
149 | "## using VADIterator class\n",
150 | "\n",
151 | "vad_iterator = VADIterator(model, sampling_rate=SAMPLING_RATE)\n",
152 | "wav = read_audio(f'en_example.wav', sampling_rate=SAMPLING_RATE)\n",
153 | "\n",
154 | "window_size_samples = 512 if SAMPLING_RATE == 16000 else 256\n",
155 | "for i in range(0, len(wav), window_size_samples):\n",
156 | " chunk = wav[i: i+ window_size_samples]\n",
157 | " if len(chunk) < window_size_samples:\n",
158 | " break\n",
159 | " speech_dict = vad_iterator(chunk, return_seconds=True)\n",
160 | " if speech_dict:\n",
161 | " print(speech_dict, end=' ')\n",
162 | "vad_iterator.reset_states() # reset model states after each audio"
163 | ]
164 | },
165 | {
166 | "cell_type": "code",
167 | "execution_count": null,
168 | "metadata": {
169 | "id": "BX3UgwwB2Fwv"
170 | },
171 | "outputs": [],
172 | "source": [
173 | "## just probabilities\n",
174 | "\n",
175 | "wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)\n",
176 | "speech_probs = []\n",
177 | "window_size_samples = 512 if SAMPLING_RATE == 16000 else 256\n",
178 | "for i in range(0, len(wav), window_size_samples):\n",
179 | " chunk = wav[i: i+ window_size_samples]\n",
180 | " if len(chunk) < window_size_samples:\n",
181 | " break\n",
182 | " speech_prob = model(chunk, SAMPLING_RATE).item()\n",
183 | " speech_probs.append(speech_prob)\n",
184 | "vad_iterator.reset_states() # reset model states after each audio\n",
185 | "\n",
186 |     "print(speech_probs[:10]) # first 10 chunk predictions"
187 | ]
188 | }
189 | ],
190 | "metadata": {
191 | "colab": {
192 | "name": "silero-vad.ipynb",
193 | "provenance": []
194 | },
195 | "kernelspec": {
196 | "display_name": "Python 3",
197 | "language": "python",
198 | "name": "python3"
199 | },
200 | "language_info": {
201 | "codemirror_mode": {
202 | "name": "ipython",
203 | "version": 3
204 | },
205 | "file_extension": ".py",
206 | "mimetype": "text/x-python",
207 | "name": "python",
208 | "nbconvert_exporter": "python",
209 | "pygments_lexer": "ipython3",
210 | "version": "3.8.8"
211 | },
212 | "toc": {
213 | "base_numbering": 1,
214 | "nav_menu": {},
215 | "number_sections": true,
216 | "sideBar": true,
217 | "skip_h1_title": false,
218 | "title_cell": "Table of Contents",
219 | "title_sidebar": "Contents",
220 | "toc_cell": false,
221 | "toc_position": {},
222 | "toc_section_display": true,
223 | "toc_window_display": false
224 | }
225 | },
226 | "nbformat": 4,
227 | "nbformat_minor": 0
228 | }
229 |
--------------------------------------------------------------------------------
/src/silero_vad/__init__.py:
--------------------------------------------------------------------------------
1 | from importlib.metadata import version
2 | try:
3 | __version__ = version(__name__)
4 | except:
5 | pass
6 |
7 | from silero_vad.model import load_silero_vad
8 | from silero_vad.utils_vad import (get_speech_timestamps,
9 | save_audio,
10 | read_audio,
11 | VADIterator,
12 | collect_chunks)
--------------------------------------------------------------------------------
/src/silero_vad/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snakers4/silero-vad/0dd45f0bcd7271463c234f3bae5ad25181f9df8b/src/silero_vad/data/__init__.py
--------------------------------------------------------------------------------
/src/silero_vad/data/silero_vad.jit:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snakers4/silero-vad/0dd45f0bcd7271463c234f3bae5ad25181f9df8b/src/silero_vad/data/silero_vad.jit
--------------------------------------------------------------------------------
/src/silero_vad/data/silero_vad.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snakers4/silero-vad/0dd45f0bcd7271463c234f3bae5ad25181f9df8b/src/silero_vad/data/silero_vad.onnx
--------------------------------------------------------------------------------
/src/silero_vad/data/silero_vad_16k_op15.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snakers4/silero-vad/0dd45f0bcd7271463c234f3bae5ad25181f9df8b/src/silero_vad/data/silero_vad_16k_op15.onnx
--------------------------------------------------------------------------------
/src/silero_vad/data/silero_vad_half.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snakers4/silero-vad/0dd45f0bcd7271463c234f3bae5ad25181f9df8b/src/silero_vad/data/silero_vad_half.onnx
--------------------------------------------------------------------------------
/src/silero_vad/model.py:
--------------------------------------------------------------------------------
1 | from .utils_vad import init_jit_model, OnnxWrapper
2 | import torch
3 | torch.set_num_threads(1)
4 |
5 |
6 | def load_silero_vad(onnx=False, opset_version=16):
7 | available_ops = [15, 16]
8 | if onnx and opset_version not in available_ops:
9 | raise Exception(f'Available ONNX opset_version: {available_ops}')
10 |
11 | if onnx:
12 | if opset_version == 16:
13 | model_name = 'silero_vad.onnx'
14 | else:
15 | model_name = f'silero_vad_16k_op{opset_version}.onnx'
16 | else:
17 | model_name = 'silero_vad.jit'
18 | package_path = "silero_vad.data"
19 |
20 | try:
21 | import importlib_resources as impresources
22 | model_file_path = str(impresources.files(package_path).joinpath(model_name))
23 |     except Exception:
24 | from importlib import resources as impresources
25 | try:
26 | with impresources.path(package_path, model_name) as f:
27 | model_file_path = f
28 |         except Exception:
29 | model_file_path = str(impresources.files(package_path).joinpath(model_name))
30 |
31 | if onnx:
32 | model = OnnxWrapper(model_file_path, force_onnx_cpu=True)
33 | else:
34 | model = init_jit_model(model_file_path)
35 |
36 | return model
37 |
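38 | # A usage sketch (not part of the upstream file): how load_silero_vad pairs with the
39 | # helpers re-exported from silero_vad.__init__. 'example.wav' is a placeholder path
40 | # assumed to exist for the demo.
41 | if __name__ == '__main__':
42 |     from silero_vad import read_audio, get_speech_timestamps
43 | 
44 |     model = load_silero_vad()  # JIT model by default; pass onnx=True for the ONNX build
45 |     wav = read_audio('example.wav', sampling_rate=16000)
46 |     print(get_speech_timestamps(wav, model))
47 | 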
--------------------------------------------------------------------------------
/tuning/README.md:
--------------------------------------------------------------------------------
1 | # Tuning the Silero-VAD model
2 | 
3 | > The tuning code was created with the support of the Innovation Promotion Foundation (Фонд содействия инновациям) within the federal project "Artificial
4 | Intelligence" of the national program "Digital Economy of the Russian Federation".
5 | 
6 | Tuning is used to improve the speech-detection quality of the Silero-VAD model on custom data.
7 | 
8 | ## Dependencies
9 | The following dependencies are used when tuning the VAD model:
10 | - `torchaudio>=0.12.0`
11 | - `omegaconf>=2.3.0`
12 | - `sklearn>=1.2.0`
13 | - `torch>=1.12.0`
14 | - `pandas>=2.2.2`
15 | - `tqdm`
16 | 
17 | ## Data preparation
18 | 
19 | Dataframes for tuning must be prepared and saved in the `.feather` format. The following columns are required in the training and validation `.feather` files:
20 | - **audio_path** - absolute path to the audio file on disk. Audio files must contain `PCM` data, preferably in `.wav` or `.opus` format (other popular audio formats are also supported). To speed up tuning, it is recommended to resample the audio files (change their sampling rate) to 16000 Hz in advance;
21 | - **speech_ts** - annotation for the corresponding audio file: a list of dictionaries of the form `{'start': START_SEC, 'end': END_SEC}`, where `START_SEC` and `END_SEC` are the start and end times of a speech segment in seconds. For good tuning results, annotation accurate to within 30 milliseconds is recommended.
22 | 
23 | The more data is used for tuning, the better the adapted model performs on the target domain. Audio length is not limited, since every audio is cut to `max_train_length_sec` seconds before being fed to the network; it is better to pre-split long audio into chunks of `max_train_length_sec` length.
24 | 
25 | An example `.feather` dataframe can be found in `example_dataframe.feather`; a sketch of how such a dataframe might be assembled is shown below.
26 | 
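A minimal sketch (not part of the upstream tooling) of how such a dataframe might be assembled with pandas; the file paths and timestamps are made-up placeholders:

```python
import pandas as pd

# One row per audio file: absolute path plus speech segments in seconds.
rows = [
    {"audio_path": "/data/audio/call_0001.wav",
     "speech_ts": [{"start": 0.51, "end": 2.84},
                   {"start": 4.10, "end": 7.95}]},
    {"audio_path": "/data/audio/call_0002.wav",
     "speech_ts": [{"start": 1.20, "end": 3.33}]},
]

df = pd.DataFrame(rows)
df.reset_index(drop=True).to_feather("train_dataset_path.feather")
```
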
27 | ## Configuration file `config.yml`
28 | 
29 | The `config.yml` configuration file contains the paths to the training and validation sets as well as the tuning parameters (a short loading sketch follows the list):
30 | - `train_dataset_path` - absolute path to the training dataframe in `.feather` format. It must contain the `audio_path` and `speech_ts` columns described in the "Data preparation" section. See `example_dataframe.feather` for an example of the dataframe layout;
31 | - `val_dataset_path` - absolute path to the validation dataframe in `.feather` format. It must contain the `audio_path` and `speech_ts` columns described in the "Data preparation" section. See `example_dataframe.feather` for an example of the dataframe layout;
32 | - `jit_model_path` - absolute path to a Silero-VAD model in `.jit` format. If this field is left empty, the model is downloaded from the repository according to the `use_torchhub` field;
33 | - `use_torchhub` - if `True`, the model to tune is loaded via torch.hub; if `False`, it is loaded via the silero-vad library (install it beforehand with `pip install silero-vad`);
34 | - `tune_8k` - selects which Silero-VAD head to tune. If `True`, the 8000 Hz head is tuned, otherwise the 16000 Hz head;
35 | - `model_save_path` - path where the tuned model is saved;
36 | - `noise_loss` - loss coefficient applied to non-speech audio windows;
37 | - `max_train_length_sec` - maximum audio length in seconds during tuning; longer audio is cut to this value;
38 | - `aug_prob` - probability of applying augmentations to an audio file during tuning;
39 | - `learning_rate` - learning rate used for tuning;
40 | - `batch_size` - batch size for tuning and validation;
41 | - `num_workers` - number of workers used for data loading;
42 | - `num_epochs` - number of tuning epochs; one epoch is a full pass over the training data;
43 | - `device` - `cpu` or `cuda`.
44 | 
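For reference, this is roughly how `tune.py` and `search_thresholds.py` consume these fields via OmegaConf (a minimal sketch, no new options introduced):

```python
from omegaconf import OmegaConf

config = OmegaConf.load("config.yml")

# Fields become attributes; e.g. the scripts derive the sampling rate from tune_8k:
sr = 8000 if config.tune_8k else 16000
print(config.device, config.batch_size, config.num_epochs, sr)
```
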
45 | ## Tuning
46 | 
47 | Tuning is started with
48 | 
49 | `python tune.py`
50 | 
51 | It runs for `num_epochs` epochs; the checkpoint with the best ROC-AUC on the validation set is saved to `model_save_path` in jit format (a loading sketch follows).
52 | 
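The saved checkpoint is a regular TorchScript model, so it can be loaded the same way the tuning scripts load one (a minimal sketch; `model_save_path.jit` is the default file name from `config.yml`):

```python
import torch

model = torch.jit.load("model_save_path.jit", map_location="cpu")
model.eval()
# from here the tuned model is used exactly like the original Silero-VAD jit model
```
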
53 | ## Threshold search
54 | 
55 | The enter and exit thresholds can be selected with
56 | 
57 | `python search_thresholds.py`
58 | 
59 | This script uses the configuration file described above. The model specified in the configuration is used to search for the optimal thresholds on the validation dataset. The sketch below shows how the found enter/exit thresholds act on per-window speech probabilities.
60 | 
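A minimal sketch of the hysteresis logic behind the two thresholds, mirroring `calculate_best_thresholds` in `tuning/utils.py` (the function name and probability values below are illustrative):

```python
def apply_thresholds(probs, ths_enter, ths_exit):
    # Hysteresis: switch to "speech" above the enter threshold, back to "non-speech"
    # below the exit threshold, otherwise keep the previous state.
    labels, is_speech = [], False
    for p in probs:
        if p >= ths_enter:
            is_speech = True
        elif p <= ths_exit:
            is_speech = False
        labels.append(1 if is_speech else 0)
    return labels

print(apply_thresholds([0.1, 0.7, 0.4, 0.2, 0.05], ths_enter=0.6, ths_exit=0.15))
# -> [0, 1, 1, 1, 0]
```
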
61 | ## Citation
62 |
63 | ```
64 | @misc{Silero VAD,
65 | author = {Silero Team},
66 | title = {Silero VAD: pre-trained enterprise-grade Voice Activity Detector (VAD), Number Detector and Language Classifier},
67 | year = {2024},
68 | publisher = {GitHub},
69 | journal = {GitHub repository},
70 | howpublished = {\url{https://github.com/snakers4/silero-vad}},
71 | commit = {insert_some_commit_here},
72 | email = {hello@silero.ai}
73 | }
74 | ```
75 |
--------------------------------------------------------------------------------
/tuning/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snakers4/silero-vad/0dd45f0bcd7271463c234f3bae5ad25181f9df8b/tuning/__init__.py
--------------------------------------------------------------------------------
/tuning/config.yml:
--------------------------------------------------------------------------------
1 | jit_model_path: '' # path to a Silero-VAD model in jit format; this model will be tuned. If the field is left empty, the model is downloaded automatically
2 | use_torchhub: True # the jit model is loaded via torchhub if True, or via the pip package if False
3 | 
4 | tune_8k: False # tunes the 16k head if False and the 8k head if True
5 | train_dataset_path: 'train_dataset_path.feather' # path to the feather-format dataset used for tuning, see README for details
6 | val_dataset_path: 'val_dataset_path.feather' # path to the feather-format dataset used for validation, see README for details
7 | model_save_path: 'model_save_path.jit' # path where the tuned model is saved
8 | 
9 | noise_loss: 0.5 # coefficient applied to the loss on non-speech windows
10 | max_train_length_sec: 8 # audio longer than this is cut to this value during tuning
11 | aug_prob: 0.4 # probability of applying augmentations to audio during tuning
12 | 
13 | learning_rate: 5e-4 # learning rate used for tuning
14 | batch_size: 128 # batch size for tuning and validation
15 | num_workers: 4 # number of workers used by the dataloaders
16 | num_epochs: 20 # number of tuning epochs, 1 epoch = a full pass over the training data
17 | device: 'cuda' # cpu or cuda, the device used for tuning
--------------------------------------------------------------------------------
/tuning/example_dataframe.feather:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/snakers4/silero-vad/0dd45f0bcd7271463c234f3bae5ad25181f9df8b/tuning/example_dataframe.feather
--------------------------------------------------------------------------------
/tuning/search_thresholds.py:
--------------------------------------------------------------------------------
1 | from utils import init_jit_model, predict, calculate_best_thresholds, SileroVadDataset, SileroVadPadder
2 | from omegaconf import OmegaConf
3 | import torch
4 | torch.set_num_threads(1)
5 |
6 | if __name__ == '__main__':
7 | config = OmegaConf.load('config.yml')
8 |
9 | loader = torch.utils.data.DataLoader(SileroVadDataset(config, mode='val'),
10 | batch_size=config.batch_size,
11 | collate_fn=SileroVadPadder,
12 | num_workers=config.num_workers)
13 |
14 | if config.jit_model_path:
15 | print(f'Loading model from the local folder: {config.jit_model_path}')
16 | model = init_jit_model(config.jit_model_path, device=config.device)
17 | else:
18 | if config.use_torchhub:
19 | print('Loading model using torch.hub')
20 | model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',
21 | model='silero_vad',
22 | onnx=False,
23 | force_reload=True)
24 | else:
25 | print('Loading model using silero-vad library')
26 | from silero_vad import load_silero_vad
27 | model = load_silero_vad(onnx=False)
28 |
29 | print('Model loaded')
30 | model.to(config.device)
31 |
32 | print('Making predicts...')
33 | all_predicts, all_gts = predict(model, loader, config.device, sr=8000 if config.tune_8k else 16000)
34 | print('Calculating thresholds...')
35 | best_ths_enter, best_ths_exit, best_acc = calculate_best_thresholds(all_predicts, all_gts)
36 |     print(f'Best enter threshold: {best_ths_enter}\nBest exit threshold: {best_ths_exit}\nBest accuracy: {best_acc}')
37 |
--------------------------------------------------------------------------------
/tuning/tune.py:
--------------------------------------------------------------------------------
1 | from utils import SileroVadDataset, SileroVadPadder, VADDecoderRNNJIT, train, validate, init_jit_model
2 | from omegaconf import OmegaConf
3 | import torch.nn as nn
4 | import torch
5 |
6 |
7 | if __name__ == '__main__':
8 | config = OmegaConf.load('config.yml')
9 |
10 | train_dataset = SileroVadDataset(config, mode='train')
11 | train_loader = torch.utils.data.DataLoader(train_dataset,
12 | batch_size=config.batch_size,
13 | collate_fn=SileroVadPadder,
14 | num_workers=config.num_workers)
15 |
16 | val_dataset = SileroVadDataset(config, mode='val')
17 | val_loader = torch.utils.data.DataLoader(val_dataset,
18 | batch_size=config.batch_size,
19 | collate_fn=SileroVadPadder,
20 | num_workers=config.num_workers)
21 |
22 | if config.jit_model_path:
23 | print(f'Loading model from the local folder: {config.jit_model_path}')
24 | model = init_jit_model(config.jit_model_path, device=config.device)
25 | else:
26 | if config.use_torchhub:
27 | print('Loading model using torch.hub')
28 | model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',
29 | model='silero_vad',
30 | onnx=False,
31 | force_reload=True)
32 | else:
33 | print('Loading model using silero-vad library')
34 | from silero_vad import load_silero_vad
35 | model = load_silero_vad(onnx=False)
36 |
37 | print('Model loaded')
38 | model.to(config.device)
39 | decoder = VADDecoderRNNJIT().to(config.device)
40 | decoder.load_state_dict(model._model_8k.decoder.state_dict() if config.tune_8k else model._model.decoder.state_dict())
41 | decoder.train()
42 | params = decoder.parameters()
43 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, params),
44 | lr=config.learning_rate)
45 | criterion = nn.BCELoss(reduction='none')
46 |
47 | best_val_roc = 0
48 | for i in range(config.num_epochs):
49 | print(f'Starting epoch {i + 1}')
50 | train_loss = train(config, train_loader, model, decoder, criterion, optimizer, config.device)
51 | val_loss, val_roc = validate(config, val_loader, model, decoder, criterion, config.device)
52 | print(f'Metrics after epoch {i + 1}:\n'
53 |               f'\tTrain loss: {round(train_loss, 3)}\n'
54 | f'\tValidation loss: {round(val_loss, 3)}\n'
55 | f'\tValidation ROC-AUC: {round(val_roc, 3)}')
56 |
57 | if val_roc > best_val_roc:
58 | print('New best ROC-AUC, saving model')
59 | best_val_roc = val_roc
60 | if config.tune_8k:
61 | model._model_8k.decoder.load_state_dict(decoder.state_dict())
62 | else:
63 | model._model.decoder.load_state_dict(decoder.state_dict())
64 | torch.jit.save(model, config.model_save_path)
65 | print('Done')
66 |
--------------------------------------------------------------------------------
/tuning/utils.py:
--------------------------------------------------------------------------------
1 | from sklearn.metrics import roc_auc_score, accuracy_score
2 | from torch.utils.data import Dataset
3 | import torch.nn as nn
4 | from tqdm import tqdm
5 | import pandas as pd
6 | import numpy as np
7 | import torchaudio
8 | import warnings
9 | import random
10 | import torch
11 | import gc
12 | warnings.filterwarnings('ignore')
13 |
14 |
15 | def read_audio(path: str,
16 | sampling_rate: int = 16000,
17 | normalize=False):
18 |
19 | wav, sr = torchaudio.load(path)
20 |
21 | if wav.size(0) > 1:
22 | wav = wav.mean(dim=0, keepdim=True)
23 |
24 | if sampling_rate:
25 | if sr != sampling_rate:
26 | transform = torchaudio.transforms.Resample(orig_freq=sr,
27 | new_freq=sampling_rate)
28 | wav = transform(wav)
29 | sr = sampling_rate
30 |
31 | if normalize and wav.abs().max() != 0:
32 | wav = wav / wav.abs().max()
33 |
34 | return wav.squeeze(0)
35 |
36 |
37 | def build_audiomentations_augs(p):
38 | from audiomentations import SomeOf, AirAbsorption, BandPassFilter, BandStopFilter, ClippingDistortion, HighPassFilter, HighShelfFilter, \
39 | LowPassFilter, LowShelfFilter, Mp3Compression, PeakingFilter, PitchShift, RoomSimulator, SevenBandParametricEQ, \
40 | Aliasing, AddGaussianNoise
41 | transforms = [Aliasing(p=1),
42 | AddGaussianNoise(p=1),
43 | AirAbsorption(p=1),
44 | BandPassFilter(p=1),
45 | BandStopFilter(p=1),
46 | ClippingDistortion(p=1),
47 | HighPassFilter(p=1),
48 | HighShelfFilter(p=1),
49 | LowPassFilter(p=1),
50 | LowShelfFilter(p=1),
51 | Mp3Compression(p=1),
52 | PeakingFilter(p=1),
53 | PitchShift(p=1),
54 | RoomSimulator(p=1, leave_length_unchanged=True),
55 | SevenBandParametricEQ(p=1)]
56 | tr = SomeOf((1, 3), transforms=transforms, p=p)
57 | return tr
58 |
59 |
60 | class SileroVadDataset(Dataset):
61 | def __init__(self,
62 | config,
63 | mode='train'):
64 |
65 | self.num_samples = 512 # constant, do not change
66 | self.sr = 16000 # constant, do not change
67 |
68 | self.resample_to_8k = config.tune_8k
69 | self.noise_loss = config.noise_loss
70 | self.max_train_length_sec = config.max_train_length_sec
71 | self.max_train_length_samples = config.max_train_length_sec * self.sr
72 |
73 | assert self.max_train_length_samples % self.num_samples == 0
74 | assert mode in ['train', 'val']
75 |
76 | dataset_path = config.train_dataset_path if mode == 'train' else config.val_dataset_path
77 | self.dataframe = pd.read_feather(dataset_path).reset_index(drop=True)
78 | self.index_dict = self.dataframe.to_dict('index')
79 | self.mode = mode
80 | print(f'DATASET SIZE : {len(self.dataframe)}')
81 |
82 | if mode == 'train':
83 | self.augs = build_audiomentations_augs(p=config.aug_prob)
84 | else:
85 | self.augs = None
86 |
87 | def __getitem__(self, idx):
88 |         idx = None if self.mode == 'train' else idx  # in train mode a random sample is drawn instead of the requested index
89 | wav, gt, mask = self.load_speech_sample(idx)
90 |
91 | if self.mode == 'train':
92 | wav = self.add_augs(wav)
93 | if len(wav) > self.max_train_length_samples:
94 | wav = wav[:self.max_train_length_samples]
95 | gt = gt[:int(self.max_train_length_samples / self.num_samples)]
96 | mask = mask[:int(self.max_train_length_samples / self.num_samples)]
97 |
98 | wav = torch.FloatTensor(wav)
99 | if self.resample_to_8k:
100 | transform = torchaudio.transforms.Resample(orig_freq=self.sr,
101 | new_freq=8000)
102 | wav = transform(wav)
103 | return wav, torch.FloatTensor(gt), torch.from_numpy(mask)
104 |
105 | def __len__(self):
106 | return len(self.index_dict)
107 |
108 | def load_speech_sample(self, idx=None):
109 | if idx is None:
110 | idx = random.randint(0, len(self.index_dict) - 1)
111 | wav = read_audio(self.index_dict[idx]['audio_path'], self.sr).numpy()
112 |
113 | if len(wav) % self.num_samples != 0:
114 | pad_num = self.num_samples - (len(wav) % (self.num_samples))
115 | wav = np.pad(wav, (0, pad_num), 'constant', constant_values=0)
116 |
117 | gt, mask = self.get_ground_truth_annotated(self.index_dict[idx]['speech_ts'], len(wav))
118 |
119 | assert len(gt) == len(wav) / self.num_samples
120 |
121 |         # non-speech windows are already down-weighted via the mask built in get_ground_truth_annotated
122 |
123 | return wav, gt, mask
124 |
125 | def get_ground_truth_annotated(self, annotation, audio_length_samples):
126 | gt = np.zeros(audio_length_samples)
127 |
128 | for i in annotation:
129 | gt[int(i['start'] * self.sr): int(i['end'] * self.sr)] = 1
130 |
131 | squeezed_predicts = np.average(gt.reshape(-1, self.num_samples), axis=1)
132 | squeezed_predicts = (squeezed_predicts > 0.5).astype(int)
133 | mask = np.ones(len(squeezed_predicts))
134 | mask[squeezed_predicts == 0] = self.noise_loss
135 | return squeezed_predicts, mask
136 |
137 | def add_augs(self, wav):
138 | while True:
139 | try:
140 | wav_aug = self.augs(wav, self.sr)
141 | if np.isnan(wav_aug.max()) or np.isnan(wav_aug.min()):
142 | return wav
143 | return wav_aug
144 |             except Exception:  # some augmentations fail on certain inputs; retry with a new random subset
145 |                 continue
146 |
147 |
148 | def SileroVadPadder(batch):
149 | wavs = [batch[i][0] for i in range(len(batch))]
150 | labels = [batch[i][1] for i in range(len(batch))]
151 | masks = [batch[i][2] for i in range(len(batch))]
152 |
153 | wavs = torch.nn.utils.rnn.pad_sequence(
154 | wavs, batch_first=True, padding_value=0)
155 |
156 | labels = torch.nn.utils.rnn.pad_sequence(
157 | labels, batch_first=True, padding_value=0)
158 |
159 | masks = torch.nn.utils.rnn.pad_sequence(
160 | masks, batch_first=True, padding_value=0)
161 |
162 | return wavs, labels, masks
163 |
164 |
165 | class VADDecoderRNNJIT(nn.Module):
166 |
167 | def __init__(self):
168 | super(VADDecoderRNNJIT, self).__init__()
169 |
170 | self.rnn = nn.LSTMCell(128, 128)
171 | self.decoder = nn.Sequential(nn.Dropout(0.1),
172 | nn.ReLU(),
173 | nn.Conv1d(128, 1, kernel_size=1),
174 | nn.Sigmoid())
175 |
176 | def forward(self, x, state=torch.zeros(0)):
177 | x = x.squeeze(-1)
178 | if len(state):
179 | h, c = self.rnn(x, (state[0], state[1]))
180 | else:
181 | h, c = self.rnn(x)
182 |
183 | x = h.unsqueeze(-1).float()
184 | state = torch.stack([h, c])
185 | x = self.decoder(x)
186 | return x, state
187 |
188 |
189 | class AverageMeter(object):
190 | """Computes and stores the average and current value"""
191 |
192 | def __init__(self):
193 | self.reset()
194 |
195 | def reset(self):
196 | self.val = 0
197 | self.avg = 0
198 | self.sum = 0
199 | self.count = 0
200 |
201 | def update(self, val, n=1):
202 | self.val = val
203 | self.sum += val * n
204 | self.count += n
205 | self.avg = self.sum / self.count
206 |
207 |
208 | def train(config,
209 | loader,
210 | jit_model,
211 | decoder,
212 | criterion,
213 | optimizer,
214 | device):
215 |
216 | losses = AverageMeter()
217 | decoder.train()
218 |
219 | context_size = 32 if config.tune_8k else 64
220 | num_samples = 256 if config.tune_8k else 512
221 | stft_layer = jit_model._model_8k.stft if config.tune_8k else jit_model._model.stft
222 | encoder_layer = jit_model._model_8k.encoder if config.tune_8k else jit_model._model.encoder
223 |
224 | with torch.enable_grad():
225 | for _, (x, targets, masks) in tqdm(enumerate(loader), total=len(loader)):
226 | targets = targets.to(device)
227 | x = x.to(device)
228 | masks = masks.to(device)
229 | x = torch.nn.functional.pad(x, (context_size, 0))
230 |
231 | outs = []
232 | state = torch.zeros(0)
233 | for i in range(context_size, x.shape[1], num_samples):
234 | input_ = x[:, i-context_size:i+num_samples]
235 | out = stft_layer(input_)
236 | out = encoder_layer(out)
237 | out, state = decoder(out, state)
238 | outs.append(out)
239 | stacked = torch.cat(outs, dim=2).squeeze(1)
240 |
241 |             loss = criterion(stacked, targets)
242 |             loss = (loss * masks).mean()
243 |             optimizer.zero_grad()  # clear gradients accumulated on the previous batch
244 |             loss.backward()
245 |             optimizer.step()
246 |             losses.update(loss.item(), masks.numel())
247 | torch.cuda.empty_cache()
248 | gc.collect()
249 |
250 | return losses.avg
251 |
252 |
253 | def validate(config,
254 | loader,
255 | jit_model,
256 | decoder,
257 | criterion,
258 | device):
259 |
260 | losses = AverageMeter()
261 | decoder.eval()
262 |
263 | predicts = []
264 | gts = []
265 |
266 | context_size = 32 if config.tune_8k else 64
267 | num_samples = 256 if config.tune_8k else 512
268 | stft_layer = jit_model._model_8k.stft if config.tune_8k else jit_model._model.stft
269 | encoder_layer = jit_model._model_8k.encoder if config.tune_8k else jit_model._model.encoder
270 |
271 | with torch.no_grad():
272 | for _, (x, targets, masks) in tqdm(enumerate(loader), total=len(loader)):
273 | targets = targets.to(device)
274 | x = x.to(device)
275 | masks = masks.to(device)
276 | x = torch.nn.functional.pad(x, (context_size, 0))
277 |
278 | outs = []
279 | state = torch.zeros(0)
280 | for i in range(context_size, x.shape[1], num_samples):
281 | input_ = x[:, i-context_size:i+num_samples]
282 | out = stft_layer(input_)
283 | out = encoder_layer(out)
284 | out, state = decoder(out, state)
285 | outs.append(out)
286 | stacked = torch.cat(outs, dim=2).squeeze(1)
287 |
288 | predicts.extend(stacked[masks != 0].tolist())
289 | gts.extend(targets[masks != 0].tolist())
290 |
291 | loss = criterion(stacked, targets)
292 | loss = (loss * masks).mean()
293 | losses.update(loss.item(), masks.numel())
294 |         score = roc_auc_score(gts, predicts)  # computed once over the full validation set
295 |
296 | torch.cuda.empty_cache()
297 | gc.collect()
298 |
299 | return losses.avg, round(score, 3)
300 |
301 |
302 | def init_jit_model(model_path: str,
303 | device=torch.device('cpu')):
304 | torch.set_grad_enabled(False)
305 | model = torch.jit.load(model_path, map_location=device)
306 | model.eval()
307 | return model
308 |
309 |
310 | def predict(model, loader, device, sr):
311 | with torch.no_grad():
312 | all_predicts = []
313 | all_gts = []
314 | for _, (x, targets, masks) in tqdm(enumerate(loader), total=len(loader)):
315 | x = x.to(device)
316 | out = model.audio_forward(x, sr=sr)
317 |
318 | for i, out_chunk in enumerate(out):
319 | predict = out_chunk[masks[i] != 0].cpu().tolist()
320 | gt = targets[i, masks[i] != 0].cpu().tolist()
321 |
322 | all_predicts.append(predict)
323 | all_gts.append(gt)
324 | return all_predicts, all_gts
325 |
326 |
327 | def calculate_best_thresholds(all_predicts, all_gts):
328 |     best_acc, best_ths_enter, best_ths_exit = 0, 0.5, 0.35  # defaults in case no threshold pair improves the accuracy
329 | for ths_enter in tqdm(np.linspace(0, 1, 20)):
330 | for ths_exit in np.linspace(0, 1, 20):
331 | if ths_exit >= ths_enter:
332 | continue
333 |
334 | accs = []
335 | for j, predict in enumerate(all_predicts):
336 | predict_bool = []
337 | is_speech = False
338 | for i in predict:
339 | if i >= ths_enter:
340 | is_speech = True
341 | predict_bool.append(1)
342 | elif i <= ths_exit:
343 | is_speech = False
344 | predict_bool.append(0)
345 | else:
346 | val = 1 if is_speech else 0
347 | predict_bool.append(val)
348 |
349 | score = round(accuracy_score(all_gts[j], predict_bool), 4)
350 | accs.append(score)
351 |
352 | mean_acc = round(np.mean(accs), 3)
353 | if mean_acc > best_acc:
354 | best_acc = mean_acc
355 | best_ths_enter = round(ths_enter, 2)
356 | best_ths_exit = round(ths_exit, 2)
357 | return best_ths_enter, best_ths_exit, best_acc
358 |
--------------------------------------------------------------------------------