├── VERSION ├── CPPLINT.cfg ├── MANIFEST.in ├── data ├── test_16k.wav └── test_with_noise_48k.wav ├── src ├── README.md ├── vad │ ├── CMakeLists.txt │ ├── onnx_model.h │ ├── vad_model.h │ ├── onnx_model.cc │ └── vad_model.cc ├── frontend │ ├── CMakeLists.txt │ ├── sample_queue.h │ ├── sample_queue.cc │ ├── resampler.h │ ├── denoiser.cc │ ├── denoiser.h │ ├── resampler.cc │ └── wav.h ├── toolchains │ ├── aarch64-linux-gnu.toolchain.cmake │ └── README.md ├── cmake │ ├── gflags.cmake │ ├── glog.cmake │ ├── samplerate.cmake │ ├── portaudio.cmake │ ├── rnnoise.cmake │ └── onnxruntime.cmake ├── bin │ ├── CMakeLists.txt │ ├── resample_main.cc │ ├── denoise_main.cc │ ├── vad_main.cc │ └── stream_vad_main.cc └── CMakeLists.txt ├── .flake8 ├── requirements.txt ├── README.md ├── .github └── workflows │ ├── black.yml │ ├── flake8.yml │ ├── isort.yml │ └── release.yml ├── .pre-commit-config.yaml ├── pysilero ├── __init__.py ├── utils.py ├── pickable_session.py ├── cli.py ├── frame_queue.py └── pysilero.py ├── LICENSE ├── setup.py └── .clang-format /VERSION: -------------------------------------------------------------------------------- 1 | 0.0.1 2 | -------------------------------------------------------------------------------- /CPPLINT.cfg: -------------------------------------------------------------------------------- 1 | root=src 2 | filter=-build/c++11 3 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements.txt 2 | include VERSION 3 | -------------------------------------------------------------------------------- /data/test_16k.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengzhendong/pysilero/HEAD/data/test_16k.wav -------------------------------------------------------------------------------- /src/README.md: 
-------------------------------------------------------------------------------- 1 | ## Runtime 2 | 3 | ``` bash 4 | $ export GLOG_logtostderr=1 5 | $ export GLOG_v=2 6 | ``` 7 | -------------------------------------------------------------------------------- /data/test_with_noise_48k.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pengzhendong/pysilero/HEAD/data/test_with_noise_48k.wav -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | show-source = true 3 | statistics = true 4 | max-line-length = 120 5 | ignore = E203, E266, E501, W503 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | audiolab 2 | librosa 3 | modelscope 4 | numpy 5 | onnxruntime 6 | soxr 7 | praat-parselmouth 8 | pyrnnoise 9 | -------------------------------------------------------------------------------- /src/vad/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_library(vad STATIC onnx_model.cc vad_model.cc) 2 | target_link_libraries(vad PUBLIC frontend glog onnxruntime) 3 | -------------------------------------------------------------------------------- /src/frontend/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_library(frontend STATIC denoiser.cc resampler.cc sample_queue.cc) 2 | target_link_libraries(frontend PUBLIC glog samplerate RnNoise) 3 | -------------------------------------------------------------------------------- /src/toolchains/aarch64-linux-gnu.toolchain.cmake: -------------------------------------------------------------------------------- 1 | set(CMAKE_SYSTEM_NAME Linux) 2 | set(CMAKE_SYSTEM_PROCESSOR aarch64) 3 | 4 | 
set(CMAKE_C_COMPILER aarch64-linux-gnu-gcc) 5 | set(CMAKE_CXX_COMPILER aarch64-linux-gnu-g++) 6 | -------------------------------------------------------------------------------- /src/cmake/gflags.cmake: -------------------------------------------------------------------------------- 1 | FetchContent_Declare(gflags 2 | URL https://github.com/gflags/gflags/archive/v2.2.2.zip 3 | URL_HASH SHA256=19713a36c9f32b33df59d1c79b4958434cb005b5b47dc5400a7a4b078111d9b5 4 | ) 5 | FetchContent_MakeAvailable(gflags) 6 | include_directories(${gflags_BINARY_DIR}/include) 7 | -------------------------------------------------------------------------------- /src/cmake/glog.cmake: -------------------------------------------------------------------------------- 1 | FetchContent_Declare(glog 2 | URL https://github.com/google/glog/archive/v0.4.0.zip 3 | URL_HASH SHA256=9e1b54eb2782f53cd8af107ecf08d2ab64b8d0dc2b7f5594472f3bd63ca85cdc 4 | ) 5 | FetchContent_MakeAvailable(glog) 6 | include_directories(${glog_SOURCE_DIR}/src ${glog_BINARY_DIR}) 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PySilero 2 | 3 | ## Silero VAD 4 | 5 | See [Silero VAD](https://github.com/snakers4/silero-vad). 6 | 7 | ### Denoiser 8 | 9 | See [RnNoise](https://github.com/werman/noise-suppression-for-voice). 
10 | 11 | ### Python Usage 12 | 13 | ```bash 14 | $ pip install pysilero 15 | $ pysilero audio.wav 16 | ``` 17 | -------------------------------------------------------------------------------- /src/cmake/samplerate.cmake: -------------------------------------------------------------------------------- 1 | FetchContent_Declare(samplerate 2 | URL https://github.com/libsndfile/libsamplerate/archive/refs/tags/0.2.2.tar.gz 3 | URL_HASH SHA256=16e881487f184250deb4fcb60432d7556ab12cb58caea71ef23960aec6c0405a 4 | ) 5 | FetchContent_MakeAvailable(samplerate) 6 | include_directories(${libsamplerate_SOURCE_DIR}/include) 7 | -------------------------------------------------------------------------------- /src/cmake/portaudio.cmake: -------------------------------------------------------------------------------- 1 | set(PA_USE_JACK OFF CACHE BOOL "Enable support for JACK Audio Connection Kit" FORCE) 2 | FetchContent_Declare(portaudio 3 | URL https://github.com/PortAudio/portaudio/archive/refs/tags/v19.7.0.tar.gz 4 | URL_HASH SHA256=5af29ba58bbdbb7bbcefaaecc77ec8fc413f0db6f4c4e286c40c3e1b83174fa0 5 | ) 6 | FetchContent_MakeAvailable(portaudio) 7 | include_directories(${portaudio_SOURCE_DIR}/include) 8 | -------------------------------------------------------------------------------- /src/toolchains/README.md: -------------------------------------------------------------------------------- 1 | ## Cross Compile 2 | 3 | Install cross compile tools: 4 | 5 | ``` bash 6 | $ sudo apt-get install gcc-aarch64-linux-gnu g++-aarch64-linux-gnu 7 | ``` 8 | 9 | Or install the binaries from: https://releases.linaro.org/components/toolchain/binaries/latest-7 10 | 11 | 12 | ## Build 13 | 14 | ``` bash 15 | $ cmake -B build -DCMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=toolchains/aarch64-linux-gnu.toolchain.cmake 16 | $ cmake --build build 17 | ``` 18 | -------------------------------------------------------------------------------- /src/bin/CMakeLists.txt: 
-------------------------------------------------------------------------------- 1 | add_executable(denoise_main denoise_main.cc) 2 | target_link_libraries(denoise_main PUBLIC frontend gflags) 3 | 4 | add_executable(resample_main resample_main.cc) 5 | target_link_libraries(resample_main PUBLIC frontend gflags) 6 | 7 | add_executable(vad_main vad_main.cc) 8 | target_link_libraries(vad_main PUBLIC vad gflags) 9 | 10 | add_executable(stream_vad_main stream_vad_main.cc) 11 | target_link_libraries(stream_vad_main PUBLIC vad gflags portaudio_static) 12 | -------------------------------------------------------------------------------- /.github/workflows/black.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: black 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | pull_request: 10 | branches: [ master ] 11 | 12 | jobs: 13 | black: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v3 17 | - uses: psf/black@stable 18 | with: 19 | options: "--line-length=120 --check --diff --color" 20 | -------------------------------------------------------------------------------- /.github/workflows/flake8.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: flake8 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | pull_request: 10 | branches: [ master ] 11 | 12 | jobs: 13 | flake8: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v3 17 | - name: Install flake8 18 | run: | 19 | python -m pip 
install -U flake8 20 | - name: Lint with flake8 21 | run: | 22 | flake8 23 | -------------------------------------------------------------------------------- /.github/workflows/isort.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: isort 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | pull_request: 10 | branches: [ master ] 11 | 12 | jobs: 13 | isort: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v3 17 | - name: Install isort 18 | run: | 19 | python -m pip install -U isort 20 | - name: Check that imports are sorted 21 | run: | 22 | isort --check --profile=black --line-length=120 --diff . 23 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: main 4 | hooks: 5 | - id: black 6 | args: [--line-length=120] 7 | 8 | - repo: https://github.com/PyCQA/flake8 9 | rev: main 10 | hooks: 11 | - id: flake8 12 | args: [--config=.flake8, --ignore=E203] 13 | 14 | - repo: https://github.com/pycqa/isort 15 | rev: main 16 | hooks: 17 | - id: isort 18 | args: [--profile=black, --line-length=120] 19 | 20 | - repo: https://github.com/pre-commit/pre-commit-hooks 21 | rev: main 22 | hooks: 23 | - id: check-executables-have-shebangs 24 | - id: end-of-file-fixer 25 | - id: mixed-line-ending 26 | - id: trailing-whitespace 27 | -------------------------------------------------------------------------------- /src/vad/onnx_model.h: -------------------------------------------------------------------------------- 1 | #ifndef VAD_ONNX_MODEL_H_ 2 | #define VAD_ONNX_MODEL_H_ 3 | 4 | #include 5 | 
#include 6 | 7 | #include "onnxruntime_cxx_api.h" // NOLINT 8 | 9 | class OnnxModel { 10 | public: 11 | static void InitEngineThreads(int num_threads = 1); 12 | OnnxModel(const std::string& model_path); 13 | 14 | protected: 15 | static Ort::Env env_; 16 | static Ort::SessionOptions session_options_; 17 | 18 | std::shared_ptr session_ = nullptr; 19 | Ort::MemoryInfo memory_info_ = 20 | Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU); 21 | 22 | std::vector input_node_names_; 23 | std::vector output_node_names_; 24 | }; 25 | 26 | #endif // VAD_ONNX_MODEL_H_ -------------------------------------------------------------------------------- /pysilero/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Zhendong Peng (pzd17@tsinghua.org.cn) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .pysilero import SileroVAD, VADIterator 16 | 17 | __all__ = ["SileroVAD", "VADIterator"] 18 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Create Release 2 | 3 | on: 4 | workflow_dispatch: 5 | inputs: 6 | version: 7 | description: "Build version (e.g. 
0.0.1)" 8 | required: true 9 | 10 | jobs: 11 | build: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v3 15 | - name: Publish 16 | env: 17 | TWINE_USERNAME: __token__ 18 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 19 | run: | 20 | echo ${{ inputs.version }} > VERSION 21 | python -m venv .venv 22 | source .venv/bin/activate 23 | python -m pip install -U setuptools wheel 24 | python setup.py bdist_wheel 25 | python -m pip install -U twine 26 | python -m twine upload dist/* 27 | -------------------------------------------------------------------------------- /src/cmake/rnnoise.cmake: -------------------------------------------------------------------------------- 1 | option(BUILD_FOR_RELEASE "Additional optimizations and steps may be taken for release" ON) 2 | option(BUILD_TESTS "" OFF) 3 | option(BUILD_VST_PLUGIN "If the VST2 plugin should be built" OFF) 4 | option(BUILD_VST3_PLUGIN "If the VST3 plugin should be built" OFF) 5 | option(BUILD_LV2_PLUGIN "If the LV2 plugin should be built" OFF) 6 | option(BUILD_LADSPA_PLUGIN "If the LADSPA plugin should be built" OFF) 7 | option(BUILD_AU_PLUGIN "If the AU plugin should be built (macOS only)" OFF) 8 | option(BUILD_AUV3_PLUGIN "If the AUv3 plugin should be built (macOS only)" OFF) 9 | 10 | FetchContent_Declare(rnnoise 11 | URL https://github.com/werman/noise-suppression-for-voice/archive/refs/tags/v1.03.tar.gz 12 | URL_HASH SHA256=8c85cae3ebbb3a18facc38930a3b67ca90e3ad609526a0018c71690de35baf04 13 | ) 14 | FetchContent_MakeAvailable(rnnoise) 15 | include_directories(${rnnoise_SOURCE_DIR}/src/rnnoise/include) 16 | link_directories(${CMAKE_BINARY_DIR}/lib) 17 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020-present Silero Team 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software 
and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /src/frontend/sample_queue.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022 Zhendong Peng (pzd17@tsinghua.org.cn) 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | #ifndef FRONTEND_SAMPLE_QUEUE_H_ 16 | #define FRONTEND_SAMPLE_QUEUE_H_ 17 | 18 | #include 19 | #include 20 | #include 21 | 22 | class SampleQueue { 23 | public: 24 | void AcceptWaveform(const std::vector& pcm); 25 | 26 | int NumSamples() const { return queue_.size(); } 27 | 28 | bool Read(int num_samples, std::vector* samples); 29 | 30 | void Clear(); 31 | 32 | private: 33 | std::queue queue_; 34 | }; 35 | 36 | #endif // FRONTEND_SAMPLE_QUEUE_H_ 37 | -------------------------------------------------------------------------------- /src/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.14 FATAL_ERROR) 2 | 3 | project(silero_vad VERSION 0.1) 4 | 5 | include(FetchContent) 6 | set(FETCHCONTENT_QUIET OFF) 7 | get_filename_component(fc_base 8 | "fc_base-${CMAKE_CXX_COMPILER_ID}" 9 | REALPATH BASE_DIR 10 | "${CMAKE_CURRENT_SOURCE_DIR}" 11 | ) 12 | set(FETCHCONTENT_BASE_DIR ${fc_base}) 13 | option(BUILD_TESTING "whether to build unit test" OFF) 14 | 15 | if(NOT MSVC) 16 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14 -pthread -fPIC") 17 | else() 18 | set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON) 19 | add_compile_options(/W0 /wd4150 /wd4244 /wd4267) 20 | add_compile_options("$<$:/utf-8>") 21 | endif() 22 | 23 | list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake) 24 | include_directories(${CMAKE_CURRENT_SOURCE_DIR}) 25 | 26 | include(gflags) 27 | set(WITH_GFLAGS OFF CACHE BOOL "whether build glog with gflags" FORCE) 28 | include(samplerate) # Note: must include libsamplerate before glog 29 | include(glog) 30 | include(portaudio) 31 | include(rnnoise) 32 | include(onnxruntime) 33 | 34 | add_subdirectory(frontend) 35 | add_dependencies(frontend RnNoise) 36 | add_subdirectory(vad) 37 | add_subdirectory(bin) 38 | -------------------------------------------------------------------------------- /src/frontend/sample_queue.cc: 
-------------------------------------------------------------------------------- 1 | // Copyright (c) 2022 Zhendong Peng (pzd17@tsinghua.org.cn) 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #include "frontend/sample_queue.h" 16 | 17 | void SampleQueue::AcceptWaveform(const std::vector& pcm) { 18 | for (auto p : pcm) { 19 | queue_.push(p); 20 | } 21 | } 22 | 23 | bool SampleQueue::Read(int num_samples, std::vector* samples) { 24 | samples->clear(); 25 | if (queue_.size() >= num_samples) { 26 | for (int i = 0; i < num_samples; i++) { 27 | samples->emplace_back(queue_.front()); 28 | queue_.pop(); 29 | } 30 | return true; 31 | } else { 32 | return false; 33 | } 34 | } 35 | 36 | // BUG FIX: queue_.empty() only *queries* emptiness and discards the result; 37 | // it never removed any samples, so Clear() (and Denoiser::Reset()) was a no-op. 38 | void SampleQueue::Clear() { 39 | while (!queue_.empty()) { 40 | queue_.pop(); 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /src/frontend/resampler.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022 Zhendong Peng (pzd17@tsinghua.org.cn) 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef FRONTEND_RESAMPLER_H_ 16 | #define FRONTEND_RESAMPLER_H_ 17 | 18 | #include 19 | #include 20 | 21 | #include "glog/logging.h" 22 | #include "samplerate.h" 23 | 24 | class Resampler { 25 | public: 26 | explicit Resampler(int in_sr, int out_sr, 27 | int converter = SRC_SINC_BEST_QUALITY); 28 | ~Resampler() { src_delete(src_state_); } 29 | 30 | void Reset() { src_reset(src_state_); } 31 | 32 | // TYPO FIX: was "enf_of_input"; the definition in resampler.cc names it 33 | // "end_of_input" (libsamplerate's SRC_DATA.end_of_input flag). 34 | void Resample(const std::vector& in_pcm, std::vector* out_pcm, 35 | int end_of_input = 0); 36 | 37 | private: 38 | float src_ratio_; 39 | SRC_STATE* src_state_; 40 | }; 41 | 42 | #endif // FRONTEND_RESAMPLER_H_ 43 | -------------------------------------------------------------------------------- /src/frontend/denoiser.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022 Zhendong Peng (pzd17@tsinghua.org.cn) 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
14 | 15 | #include "frontend/denoiser.h" 16 | 17 | void Denoiser::Denoise(const std::vector& in_pcm, 18 | std::vector* out_pcm) { 19 | sample_queue_->AcceptWaveform(in_pcm); 20 | int num_frames = sample_queue_->NumSamples() / FRAME_SIZE; 21 | int num_out_samples = num_frames * FRAME_SIZE; 22 | 23 | std::vector input_pcm; 24 | sample_queue_->Read(num_out_samples, &input_pcm); 25 | out_pcm->resize(num_out_samples); 26 | 27 | for (int i = 0; i < num_frames; i++) { 28 | float* in_frames = input_pcm.data() + i * FRAME_SIZE; 29 | float* out_frames = out_pcm->data() + i * FRAME_SIZE; 30 | rnnoise_process_frame(st_, out_frames, in_frames); 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /src/frontend/denoiser.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022 Zhendong Peng (pzd17@tsinghua.org.cn) 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | #ifndef FRONTEND_DENOISER_H_ 16 | #define FRONTEND_DENOISER_H_ 17 | 18 | #include 19 | #include 20 | 21 | #include "glog/logging.h" 22 | #include "rnnoise/rnnoise.h" 23 | 24 | #include "frontend/sample_queue.h" 25 | 26 | #define FRAME_SIZE 480 // According to rnnoise/src/denoise.c 27 | 28 | class Denoiser { 29 | public: 30 | explicit Denoiser() { 31 | st_ = rnnoise_create(); 32 | sample_queue_ = std::make_shared(); 33 | }; 34 | ~Denoiser() { rnnoise_destroy(st_); }; 35 | 36 | void Reset() { sample_queue_->Clear(); }; 37 | 38 | void Denoise(const std::vector& in_pcm, std::vector* out_pcm); 39 | 40 | private: 41 | DenoiseState* st_ = nullptr; 42 | std::shared_ptr sample_queue_ = nullptr; 43 | }; 44 | 45 | #endif // FRONTEND_DENOISER_H_ 46 | -------------------------------------------------------------------------------- /src/vad/vad_model.h: -------------------------------------------------------------------------------- 1 | #ifndef VAD_VAD_MODEL_H_ 2 | #define VAD_VAD_MODEL_H_ 3 | 4 | #include "vad/onnx_model.h" 5 | 6 | #include 7 | 8 | #include "frontend/denoiser.h" 9 | #include "frontend/resampler.h" 10 | #include "frontend/sample_queue.h" 11 | 12 | #define SIZE_STATE 256 // 2 * 1 * 128 13 | 14 | class VadModel : public OnnxModel { 15 | public: 16 | VadModel(const std::string& model_path, bool denoise, int sample_rate, 17 | float threshold, int min_sil_dur_ms = 100, int speech_pad_ms = 30); 18 | 19 | void Reset(); 20 | 21 | void AcceptWaveform(const std::vector& pcm); 22 | void Vad(float* speech_start, float* speech_end, bool return_relative = false, 23 | bool return_seconds = false); 24 | 25 | private: 26 | float Forward(const std::vector& pcm); 27 | 28 | int frame_ms_ = 32; 29 | int frame_size_ = frame_ms_ * (16000 / 1000); 30 | 31 | bool denoise_ = false; 32 | int sample_rate_; 33 | float threshold_; 34 | int min_sil_dur_samples_; 35 | int speech_pad_samples_; 36 | 37 | // model states 38 | bool on_speech_ = false; 39 | float temp_end_ = 0; 40 | int 
current_sample_ = 0; 41 | 42 | // Onnx model 43 | std::vector state_; 44 | // std::vector context_; TODO: add context 45 | 46 | std::shared_ptr denoiser_ = nullptr; 47 | std::shared_ptr upsampler_ = nullptr; 48 | std::shared_ptr downsampler_ = nullptr; 49 | std::shared_ptr sample_queue_ = nullptr; 50 | }; 51 | 52 | #endif // VAD_VAD_MODEL_H_ 53 | -------------------------------------------------------------------------------- /pysilero/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Zhendong Peng (pzd17@tsinghua.org.cn) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import warnings 16 | 17 | import librosa 18 | import numpy as np 19 | import parselmouth 20 | 21 | warnings.filterwarnings("ignore") 22 | 23 | 24 | def get_energy(chunk, sr, from_harmonic=1, to_harmonic=5): 25 | sound = parselmouth.Sound(chunk, sampling_frequency=sr) 26 | # pitch 27 | pitch = sound.to_pitch(pitch_floor=100, pitch_ceiling=350) 28 | # pitch energy 29 | # energy = np.mean(pitch.selected_array["strength"]) 30 | pitch = np.mean(pitch.selected_array["frequency"]) 31 | # frame log energy 32 | # energy = np.mean(sound.to_mfcc().to_array(), axis=1)[0] 33 | 34 | # energy form x-th harmonic to y-th harmonic 35 | freqs = librosa.fft_frequencies(sr=sr) 36 | freq_band_idx = np.where((freqs >= from_harmonic * pitch) & (freqs <= to_harmonic * pitch))[0] 37 | energy = np.sum(np.abs(librosa.stft(chunk)[freq_band_idx, :])) 38 | 39 | return energy 40 | -------------------------------------------------------------------------------- /src/frontend/resampler.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022 Zhendong Peng (pzd17@tsinghua.org.cn) 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | #include "frontend/resampler.h" 16 | 17 | Resampler::Resampler(int in_sr, int out_sr, int converter) { 18 | src_state_ = src_new(converter, 1, nullptr); 19 | src_ratio_ = out_sr * 1.0 / in_sr; 20 | src_set_ratio(src_state_, src_ratio_); 21 | } 22 | 23 | void Resampler::Resample(const std::vector& in_pcm, 24 | std::vector* out_pcm, int end_of_input) { 25 | out_pcm->resize(in_pcm.size() * src_ratio_); 26 | 27 | SRC_DATA src_data; 28 | src_data.src_ratio = src_ratio_; 29 | src_data.end_of_input = end_of_input; 30 | src_data.data_in = in_pcm.data(); 31 | src_data.input_frames = in_pcm.size(); 32 | src_data.data_out = out_pcm->data(); 33 | src_data.output_frames = out_pcm->size(); 34 | 35 | int error = src_process(src_state_, &src_data); 36 | if (error != 0) { 37 | LOG(FATAL) << "src_process error: " << src_strerror(error); 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Zhendong Peng (pzd17@tsinghua.org.cn) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from setuptools import find_packages, setup 16 | 17 | with open("requirements.txt", encoding="utf8") as f: 18 | requirements = f.readlines() 19 | 20 | setup( 21 | name="pysilero", 22 | version=open("VERSION", encoding="utf8").read(), 23 | author="Zhendong Peng", 24 | author_email="pzd17@tsinghua.org.cn", 25 | long_description=open("README.md", encoding="utf8").read(), 26 | long_description_content_type="text/markdown", 27 | url="https://github.com/pengzhendong/pysilero", 28 | packages=find_packages(), 29 | package_data={ 30 | "pysilero": ["*.onnx"], 31 | }, 32 | install_requires=requirements, 33 | entry_points={ 34 | "console_scripts": [ 35 | "pysilero = pysilero.cli:main", 36 | ] 37 | }, 38 | classifiers=[ 39 | "Programming Language :: Python :: 3", 40 | "Operating System :: OS Independent", 41 | "Topic :: Scientific/Engineering", 42 | ], 43 | ) 44 | -------------------------------------------------------------------------------- /src/vad/onnx_model.cc: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "vad/onnx_model.h" 4 | 5 | #include "glog/logging.h" 6 | 7 | Ort::Env OnnxModel::env_ = Ort::Env(ORT_LOGGING_LEVEL_WARNING, ""); 8 | Ort::SessionOptions OnnxModel::session_options_ = Ort::SessionOptions(); 9 | 10 | void OnnxModel::InitEngineThreads(int num_threads) { 11 | session_options_.SetIntraOpNumThreads(num_threads); 12 | session_options_.SetGraphOptimizationLevel( 13 | GraphOptimizationLevel::ORT_ENABLE_ALL); 14 | } 15 | 16 | static std::wstring ToWString(const std::string& str) { 17 | unsigned len = str.size() * 2; 18 | setlocale(LC_CTYPE, ""); 19 | wchar_t* p = new wchar_t[len]; 20 | mbstowcs(p, str.c_str(), len); 21 | std::wstring wstr(p); 22 | delete[] p; 23 | return wstr; 24 | } 25 | 26 | OnnxModel::OnnxModel(const std::string& model_path) { 27 | InitEngineThreads(1); 28 | #ifdef _MSC_VER 29 | session_ = std::make_shared(env_, ToWString(model_path).c_str(), 30 | session_options_); 31 
| #else 32 | session_ = std::make_shared(env_, model_path.c_str(), 33 | session_options_); 34 | #endif 35 | Ort::AllocatorWithDefaultOptions allocator; 36 | // Input info 37 | int num_nodes = session_->GetInputCount(); 38 | input_node_names_.resize(num_nodes); 39 | for (int i = 0; i < num_nodes; ++i) { 40 | input_node_names_[i] = session_->GetInputName(i, allocator); 41 | LOG(INFO) << "Input names[" << i << "]: " << input_node_names_[i]; 42 | } 43 | // Output info 44 | num_nodes = session_->GetOutputCount(); 45 | output_node_names_.resize(num_nodes); 46 | for (int i = 0; i < num_nodes; ++i) { 47 | output_node_names_[i] = session_->GetOutputName(i, allocator); 48 | LOG(INFO) << "Output names[" << i << "]: " << output_node_names_[i]; 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /src/bin/resample_main.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2024 Zhendong Peng (pzd17@tsinghua.org.cn) 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | #include 16 | 17 | #include "gflags/gflags.h" 18 | 19 | #include "frontend/resampler.h" 20 | #include "frontend/wav.h" 21 | 22 | DEFINE_string(input_wav_path, "", "input wav path"); 23 | DEFINE_int32(output_sample_rate, 16000, "output sample rate"); 24 | DEFINE_string(output_wav_path, "", "output wav path"); 25 | 26 | int main(int argc, char* argv[]) { 27 | gflags::ParseCommandLineFlags(&argc, &argv, false); 28 | google::InitGoogleLogging(argv[0]); 29 | 30 | wav::WavReader wav_reader(FLAGS_input_wav_path); 31 | int num_channels = wav_reader.num_channels(); 32 | CHECK_EQ(num_channels, 1) << "Only support mono (1 channel) wav!"; 33 | int sample_rate = wav_reader.sample_rate(); 34 | const float* pcm = wav_reader.data(); 35 | int num_samples = wav_reader.num_samples(); 36 | std::vector input_pcm{pcm, pcm + num_samples}; 37 | 38 | std::vector output_pcm; 39 | Resampler resampler(sample_rate, FLAGS_output_sample_rate); 40 | resampler.Resample(input_pcm, &output_pcm, true); 41 | 42 | wav::WavWriter writer(output_pcm.data(), output_pcm.size(), num_channels, 43 | FLAGS_output_sample_rate, 16); 44 | writer.Write(FLAGS_output_wav_path); 45 | } 46 | -------------------------------------------------------------------------------- /src/cmake/onnxruntime.cmake: -------------------------------------------------------------------------------- 1 | if(${CMAKE_SYSTEM_NAME} STREQUAL "Windows") 2 | set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.12.0/onnxruntime-win-x64-1.12.0.zip") 3 | set(URL_HASH "SHA256=8b5d61204989350b7904ac277f5fbccd3e6736ddbb6ec001e412723d71c9c176") 4 | elseif(${CMAKE_SYSTEM_NAME} STREQUAL "Linux") 5 | if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64") 6 | set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.12.0/onnxruntime-linux-aarch64-1.12.0.tgz") 7 | set(URL_HASH "SHA256=5820d9f343df73c63b6b2b174a1ff62575032e171c9564bcf92060f46827d0ac") 8 | else() 9 | set(ONNX_URL 
"https://github.com/microsoft/onnxruntime/releases/download/v1.12.0/onnxruntime-linux-x64-1.12.0.tgz") 10 | set(URL_HASH "SHA256=5d503ce8540358b59be26c675e42081be14a3e833a5301926f555451046929c5") 11 | endif() 12 | elseif(${CMAKE_SYSTEM_NAME} STREQUAL "Darwin") 13 | if(CMAKE_SYSTEM_PROCESSOR MATCHES "arm64") 14 | set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.12.0/onnxruntime-osx-arm64-1.12.0.tgz") 15 | set(URL_HASH "SHA256=23117b6f5d7324d4a7c51184e5f808dd952aec411a6b99a1b6fd1011de06e300") 16 | else() 17 | set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.12.0/onnxruntime-osx-x86_64-1.12.0.tgz") 18 | set(URL_HASH "SHA256=09b17f712f8c6f19bb63da35d508815b443cbb473e16c6192abfaa297c02f600") 19 | endif() 20 | else() 21 | message(FATAL_ERROR "Unsupported CMake System Name '${CMAKE_SYSTEM_NAME}' (expected 'Windows', 'Linux' or 'Darwin')") 22 | endif() 23 | 24 | FetchContent_Declare(onnxruntime 25 | URL ${ONNX_URL} 26 | URL_HASH ${URL_HASH} 27 | ) 28 | FetchContent_MakeAvailable(onnxruntime) 29 | include_directories(${onnxruntime_SOURCE_DIR}/include) 30 | link_directories(${onnxruntime_SOURCE_DIR}/lib) 31 | 32 | if(MSVC) 33 | file(GLOB ONNX_DLLS "${onnxruntime_SOURCE_DIR}/lib/*.dll") 34 | if(CMAKE_BUILD_TYPE) 35 | file(COPY ${ONNX_DLLS} DESTINATION ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE}) 36 | else() 37 | file(COPY ${ONNX_DLLS} DESTINATION ${CMAKE_BINARY_DIR}/${CMAKE_BUILD_TYPE_INIT}) 38 | endif() 39 | endif() 40 | -------------------------------------------------------------------------------- /pysilero/pickable_session.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Zhendong Peng (pzd17@tsinghua.org.cn) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from functools import partial 16 | 17 | import onnxruntime as ort 18 | from modelscope import snapshot_download 19 | 20 | 21 | class PickableSession: 22 | """ 23 | This is a wrapper to make the current InferenceSession class pickable. 24 | """ 25 | 26 | def __init__(self, version="v5"): 27 | opts = ort.SessionOptions() 28 | opts.inter_op_num_threads = 1 29 | opts.intra_op_num_threads = 1 30 | opts.log_severity_level = 3 31 | 32 | assert version in ["v4", "v5"] 33 | model_id = "pengzhendong/silero-vad" 34 | try: 35 | repo_dir = snapshot_download(model_id) 36 | except Exception: 37 | from modelscope.utils.file_utils import get_default_modelscope_cache_dir 38 | 39 | repo_dir = f"{get_default_modelscope_cache_dir()}/models/{model_id}" 40 | self.model_path = f"{repo_dir}/{version}/silero_vad.onnx" 41 | self.init_session = partial(ort.InferenceSession, sess_options=opts, providers=["CPUExecutionProvider"]) 42 | self.sess = self.init_session(self.model_path) 43 | 44 | def run(self, *args): 45 | return self.sess.run(None, *args) 46 | 47 | def __getstate__(self): 48 | return {"model_path": self.model_path} 49 | 50 | def __setstate__(self, values): 51 | self.model_path = values["model_path"] 52 | self.sess = self.init_session(self.model_path) 53 | 54 | 55 | VERSIONS = ["v4", "v5"] 56 | silero_vad = {version: PickableSession(version) for version in VERSIONS} 57 | -------------------------------------------------------------------------------- /src/bin/denoise_main.cc: 
-------------------------------------------------------------------------------- 1 | // Copyright (c) 2024 Zhendong Peng (pzd17@tsinghua.org.cn) 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #include 16 | 17 | #include "gflags/gflags.h" 18 | 19 | #include "frontend/denoiser.h" 20 | #include "frontend/resampler.h" 21 | #include "frontend/wav.h" 22 | 23 | DEFINE_string(input_wav_path, "", "input wav path"); 24 | DEFINE_string(output_wav_path, "", "output wav path"); 25 | 26 | int main(int argc, char* argv[]) { 27 | gflags::ParseCommandLineFlags(&argc, &argv, false); 28 | google::InitGoogleLogging(argv[0]); 29 | 30 | wav::WavReader wav_reader(FLAGS_input_wav_path); 31 | int num_channels = wav_reader.num_channels(); 32 | CHECK_EQ(num_channels, 1) << "Only support mono (1 channel) wav!"; 33 | int sample_rate = wav_reader.sample_rate(); 34 | const float* pcm = wav_reader.data(); 35 | int num_samples = wav_reader.num_samples(); 36 | std::vector input_pcm{pcm, pcm + num_samples}; 37 | 38 | std::vector output_pcm; 39 | // 0. Upsample to 48k for RnNoise 40 | if (sample_rate != 48000) { 41 | Resampler upsampler(sample_rate, 48000); 42 | upsampler.Resample(input_pcm, &output_pcm, true); 43 | input_pcm = output_pcm; 44 | } 45 | // 1. Denoise with RnNoise 46 | Denoiser denoiser; 47 | denoiser.Denoise(input_pcm, &output_pcm); 48 | // 2. 
Downsample back to original sample rate 49 | if (sample_rate != 48000) { 50 | input_pcm = output_pcm; 51 | Resampler downsampler(48000, sample_rate); 52 | downsampler.Resample(input_pcm, &output_pcm, true); 53 | } 54 | 55 | wav::WavWriter writer(output_pcm.data(), output_pcm.size(), num_channels, 56 | sample_rate, 16); 57 | writer.Write(FLAGS_output_wav_path); 58 | } 59 | -------------------------------------------------------------------------------- /src/bin/vad_main.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2022 Zhendong Peng (pzd17@tsinghua.org.cn) 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | #include 16 | #include 17 | 18 | #include "gflags/gflags.h" 19 | 20 | #include "frontend/wav.h" 21 | #include "vad/vad_model.h" 22 | 23 | DEFINE_string(wav_path, "", "wav path"); 24 | DEFINE_double(threshold, 0.5, "threshold of voice activity detection"); 25 | DEFINE_string(model_path, "", "voice activity detection model path"); 26 | 27 | int main(int argc, char* argv[]) { 28 | gflags::ParseCommandLineFlags(&argc, &argv, false); 29 | google::InitGoogleLogging(argv[0]); 30 | 31 | wav::WavReader wav_reader(FLAGS_wav_path); 32 | int num_channels = wav_reader.num_channels(); 33 | CHECK_EQ(num_channels, 1) << "Only support mono (1 channel) wav!"; 34 | int sample_rate = wav_reader.sample_rate(); 35 | const float* pcm = wav_reader.data(); 36 | int num_samples = wav_reader.num_samples(); 37 | 38 | const int frame_size_ms = 10; 39 | const int frame_size_samples = frame_size_ms * sample_rate / 1000; 40 | VadModel vad(FLAGS_model_path, true, sample_rate, FLAGS_threshold); 41 | 42 | for (int i = 0; i < num_samples; i += frame_size_samples) { 43 | // Extract 10ms frame from input_pcm 44 | int remaining_samples = std::min(frame_size_samples, num_samples - i); 45 | std::vector input_pcm(pcm + i, pcm + i + remaining_samples); 46 | vad.AcceptWaveform(input_pcm); 47 | float speech_start = -1; 48 | float speech_end = -1; 49 | vad.Vad(&speech_start, &speech_end, false, true); 50 | if (speech_start >= 0) { 51 | std::cout << "[" << speech_start << ", "; 52 | } 53 | if (speech_end >= 0) { 54 | std::cout << speech_end << "]" << std::endl; 55 | } 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /pysilero/cli.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Zhendong Peng (pzd17@tsinghua.org.cn) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import click 16 | import matplotlib.pyplot as plt 17 | import numpy as np 18 | from audiolab import Reader, Writer, info, load_audio 19 | 20 | from pysilero import SileroVAD, VADIterator 21 | 22 | 23 | @click.command() 24 | @click.argument("wav_path", type=click.Path(exists=True, file_okay=True)) 25 | @click.option("--version", default="v5", help="Silero VAD version") 26 | @click.option("--denoise/--no-denoise", default=False, help="Denoise before vad") 27 | @click.option("--streaming/--no-streaming", default=False, help="Streming mode") 28 | @click.option("--save-path", help="Save path for output audio") 29 | @click.option("--plot/--no-plot", default=False, help="Plot the vad probabilities") 30 | def main(wav_path, version, denoise, streaming, save_path, plot): 31 | if not streaming: 32 | model = SileroVAD(version, info(wav_path).rate, denoise=denoise) 33 | speech_timestamps = model.get_speech_timestamps(wav_path, return_seconds=True, save_path=save_path) 34 | print("None streaming result:", list(speech_timestamps)) 35 | 36 | if plot: 37 | audio, rate = load_audio(wav_path, dtype=np.float32) 38 | x1 = np.arange(0, audio.shape[1]) / rate 39 | outputs = list(model.get_speech_probs(wav_path)) 40 | x2 = [i * 32 / 1000 for i in range(0, len(outputs))] 41 | plt.plot(x1, audio[0]) 42 | plt.plot(x2, outputs) 43 | plt.show() 44 | else: 45 | print("Streaming result:", end=" ") 46 | reader = Reader(wav_path, dtype=np.float32, frame_size_ms=10) 47 | if save_path is not None: 48 | writer = Writer(save_path, reader.rate, 
layout=reader.layout) 49 | vad_iterator = VADIterator(version, reader.rate) 50 | for idx, (frame, _) in enumerate(reader): 51 | partial = idx == reader.num_frames - 1 52 | for speech_dict, speech_samples in vad_iterator(frame.squeeze(), partial, return_seconds=True): 53 | if "start" in speech_dict or "end" in speech_dict: 54 | print(speech_dict, end=" ") 55 | if save_path is not None and speech_samples is not None: 56 | writer.write(speech_samples) 57 | 58 | 59 | if __name__ == "__main__": 60 | main() 61 | -------------------------------------------------------------------------------- /.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | Language: Cpp 3 | # BasedOnStyle: Google 4 | AccessModifierOffset: -1 5 | AlignAfterOpenBracket: Align 6 | AlignConsecutiveAssignments: false 7 | AlignConsecutiveDeclarations: false 8 | AlignEscapedNewlinesLeft: true 9 | AlignOperands: true 10 | AlignTrailingComments: true 11 | AllowAllParametersOfDeclarationOnNextLine: true 12 | AllowShortBlocksOnASingleLine: false 13 | AllowShortCaseLabelsOnASingleLine: false 14 | AllowShortFunctionsOnASingleLine: All 15 | AllowShortIfStatementsOnASingleLine: true 16 | AllowShortLoopsOnASingleLine: true 17 | AlwaysBreakAfterDefinitionReturnType: None 18 | AlwaysBreakAfterReturnType: None 19 | AlwaysBreakBeforeMultilineStrings: true 20 | AlwaysBreakTemplateDeclarations: true 21 | BinPackArguments: true 22 | BinPackParameters: true 23 | BraceWrapping: 24 | AfterClass: false 25 | AfterControlStatement: false 26 | AfterEnum: false 27 | AfterFunction: false 28 | AfterNamespace: false 29 | AfterObjCDeclaration: false 30 | AfterStruct: false 31 | AfterUnion: false 32 | BeforeCatch: false 33 | BeforeElse: false 34 | IndentBraces: false 35 | BreakBeforeBinaryOperators: None 36 | BreakBeforeBraces: Attach 37 | BreakBeforeTernaryOperators: true 38 | BreakConstructorInitializersBeforeComma: false 39 | BreakAfterJavaFieldAnnotations: false 40 | 
BreakStringLiterals: true 41 | ColumnLimit: 80 42 | CommentPragmas: '^ IWYU pragma:' 43 | ConstructorInitializerAllOnOneLineOrOnePerLine: true 44 | ConstructorInitializerIndentWidth: 4 45 | ContinuationIndentWidth: 4 46 | Cpp11BracedListStyle: true 47 | DisableFormat: false 48 | ExperimentalAutoDetectBinPacking: false 49 | ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ] 50 | IncludeCategories: 51 | - Regex: '^<.*\.h>' 52 | Priority: 1 53 | - Regex: '^<.*' 54 | Priority: 2 55 | - Regex: '.*' 56 | Priority: 3 57 | IncludeIsMainRegex: '([-_](test|unittest))?$' 58 | IndentCaseLabels: true 59 | IndentWidth: 2 60 | IndentWrappedFunctionNames: false 61 | JavaScriptQuotes: Leave 62 | JavaScriptWrapImports: true 63 | KeepEmptyLinesAtTheStartOfBlocks: false 64 | MacroBlockBegin: '' 65 | MacroBlockEnd: '' 66 | MaxEmptyLinesToKeep: 1 67 | NamespaceIndentation: None 68 | ObjCBlockIndentWidth: 2 69 | ObjCSpaceAfterProperty: false 70 | ObjCSpaceBeforeProtocolList: false 71 | PenaltyBreakBeforeFirstCallParameter: 1 72 | PenaltyBreakComment: 300 73 | PenaltyBreakFirstLessLess: 120 74 | PenaltyBreakString: 1000 75 | PenaltyExcessCharacter: 1000000 76 | PenaltyReturnTypeOnItsOwnLine: 200 77 | PointerAlignment: Left 78 | ReflowComments: true 79 | SortIncludes: true 80 | SpaceAfterCStyleCast: false 81 | SpaceBeforeAssignmentOperators: true 82 | SpaceBeforeParens: ControlStatements 83 | SpaceInEmptyParentheses: false 84 | SpacesBeforeTrailingComments: 2 85 | SpacesInAngles: false 86 | SpacesInContainerLiterals: true 87 | SpacesInCStyleCastParentheses: false 88 | SpacesInParentheses: false 89 | SpacesInSquareBrackets: false 90 | Standard: Auto 91 | TabWidth: 8 92 | UseTab: Never 93 | ... 
94 | -------------------------------------------------------------------------------- /src/bin/stream_vad_main.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2024 Zhendong Peng (pzd17@tsinghua.org.cn) 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #include 16 | 17 | #include 18 | #include 19 | 20 | #include "gflags/gflags.h" 21 | #include "portaudio.h" 22 | 23 | #include "vad/vad_model.h" 24 | 25 | DEFINE_double(threshold, 0.5, "threshold of voice activity detection"); 26 | DEFINE_string(model_path, "", "voice activity detection model path"); 27 | 28 | int g_exiting = 0; 29 | int sample_rate = 16000; 30 | std::shared_ptr vad; 31 | 32 | void SigRoutine(int dunno) { 33 | if (dunno == SIGINT) { 34 | g_exiting = 1; 35 | } 36 | } 37 | 38 | static int RecordCallback(const void* input, void* output, 39 | unsigned long frames_count, 40 | const PaStreamCallbackTimeInfo* time_info, 41 | PaStreamCallbackFlags status_flags, void* user_data) { 42 | const auto* pcm_data = static_cast(input); 43 | std::vector pcm(pcm_data, pcm_data + frames_count); 44 | vad->AcceptWaveform(pcm); 45 | if (g_exiting) { 46 | LOG(INFO) << "Exiting loop."; 47 | return paComplete; 48 | } else { 49 | return paContinue; 50 | } 51 | } 52 | 53 | int main(int argc, char* argv[]) { 54 | gflags::ParseCommandLineFlags(&argc, &argv, false); 55 | google::InitGoogleLogging(argv[0]); 56 | vad = 
std::make_shared(FLAGS_model_path, true, sample_rate, 57 | FLAGS_threshold); 58 | 59 | signal(SIGINT, SigRoutine); 60 | PaError err = Pa_Initialize(); 61 | PaStreamParameters params; 62 | std::cout << err << " " << Pa_GetDeviceCount() << std::endl; 63 | params.device = Pa_GetDefaultInputDevice(); 64 | if (params.device == paNoDevice) { 65 | LOG(FATAL) << "Error: No default input device."; 66 | } 67 | params.channelCount = 1; 68 | params.sampleFormat = paInt16; 69 | params.suggestedLatency = 70 | Pa_GetDeviceInfo(params.device)->defaultLowInputLatency; 71 | params.hostApiSpecificStreamInfo = NULL; 72 | PaStream* stream; 73 | // Callback and process pcm date each `interval` ms. 74 | int interval_ms = 10; 75 | const int frame_size_samples = interval_ms * sample_rate / 1000; 76 | Pa_OpenStream(&stream, ¶ms, NULL, sample_rate, frame_size_samples, 77 | paClipOff, RecordCallback, NULL); 78 | Pa_StartStream(stream); 79 | LOG(INFO) << "=== Now recording!! Please speak into the microphone. ==="; 80 | 81 | while (Pa_IsStreamActive(stream)) { 82 | float speech_start = -1; 83 | float speech_end = -1; 84 | vad->Vad(&speech_start, &speech_end, false, true); 85 | if (speech_start >= 0) { 86 | LOG(INFO) << "start: " << speech_start; 87 | } 88 | if (speech_end >= 0) { 89 | LOG(INFO) << "end: " << speech_end; 90 | } 91 | } 92 | Pa_StopStream(stream); 93 | Pa_CloseStream(stream); 94 | Pa_Terminate(); 95 | } 96 | -------------------------------------------------------------------------------- /pysilero/frame_queue.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Zhendong Peng (pzd17@tsinghua.org.cn) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | import soxr 17 | 18 | 19 | class FrameQueue: 20 | def __init__(self, frame_size, in_rate, speech_pad_samples=0, out_rate=None, padding=True): 21 | self.frame_size = frame_size 22 | # padding zeros for the last frame 23 | self.padding = padding 24 | self.speech_pad_samples = speech_pad_samples 25 | # cache the original samples for padding and soxr's delay 26 | # TODO: use the largest delay of soxr instead of 500ms cache 27 | num_cached_samples = speech_pad_samples + 500 * in_rate // 1000 28 | self.cached_samples = np.zeros(num_cached_samples, dtype=np.float32) 29 | self.cache_start = -len(self.cached_samples) 30 | 31 | self.current_sample = 0 32 | self.remained_samples = np.empty(0, dtype=np.float32) 33 | 34 | if out_rate is None or in_rate == out_rate: 35 | self.step = 1.0 36 | self.resampler = None 37 | else: 38 | self.step = in_rate / out_rate 39 | self.resampler = soxr.ResampleStream(in_rate, out_rate, num_channels=1) 40 | 41 | def add_chunk(self, chunk, is_last=False): 42 | # cache the original frame without resampling for `lookforward` of vad start 43 | # cache start is the absolute sample index of the first sample in the cached_samples 44 | if len(chunk) > 0: 45 | self.cache_start += len(chunk) 46 | self.cached_samples = np.roll(self.cached_samples, -len(chunk)) 47 | self.cached_samples[-len(chunk) :] = chunk[-len(self.cached_samples) :] 48 | # resample 49 | if self.resampler is not None: 50 | chunk = self.resampler.resample_chunk(chunk, is_last) 51 | # enqueue chunk 52 | 
self.remained_samples = np.concatenate((self.remained_samples, chunk)) 53 | 54 | while len(self.remained_samples) >= self.frame_size: 55 | frame = self.remained_samples[: self.frame_size] 56 | self.remained_samples = self.remained_samples[self.frame_size :] 57 | # frame_start and frame_end is the sample index before resampling 58 | frame_start = self.current_sample 59 | self.current_sample += int(len(frame) * self.step) 60 | frame_end = self.current_sample 61 | yield frame_start, frame_end, frame 62 | 63 | if is_last and len(self.remained_samples) > 0 and self.padding: 64 | frame = self.remained_samples 65 | frame_start = self.current_sample 66 | self.current_sample += int(len(frame) * self.step) 67 | frame = np.pad(frame, (0, self.frame_size - len(frame))) 68 | frame_end = self.current_sample 69 | yield frame_start, frame_end, frame 70 | 71 | def get_frame(self, speech_padding=False): 72 | # dequeue one original frame without resampling 73 | frame_start = self.current_sample - int(self.frame_size * self.step) 74 | frame_end = self.current_sample 75 | if speech_padding: 76 | frame_start -= self.speech_pad_samples 77 | # get the relative sample index of the speech 78 | speech_start = frame_start - self.cache_start 79 | speech_end = frame_end - self.cache_start 80 | return self.cached_samples[speech_start:speech_end] 81 | 82 | 83 | if __name__ == "__main__": 84 | queue = FrameQueue(3, 1000) 85 | frames = [[1, 2, 3], [4, 5], [6, 7, 8]] 86 | for index, frame in enumerate(frames): 87 | for frame_start, frame_end, frame in queue.add_chunk(frame, index == len(frames) - 1): 88 | print(frame_start, frame_end, frame) 89 | -------------------------------------------------------------------------------- /src/vad/vad_model.cc: -------------------------------------------------------------------------------- 1 | #include "vad/vad_model.h" 2 | 3 | #include 4 | #include 5 | 6 | #include "glog/logging.h" 7 | 8 | VadModel::VadModel(const std::string& model_path, bool denoise, int 
sample_rate, 9 | float threshold, int min_sil_dur_ms, int speech_pad_ms) 10 | : OnnxModel(model_path), 11 | denoise_(denoise), 12 | sample_rate_(sample_rate), 13 | threshold_(threshold), 14 | min_sil_dur_samples_(min_sil_dur_ms * sample_rate / 1000), 15 | speech_pad_samples_(speech_pad_ms * sample_rate / 1000) { 16 | denoiser_ = std::make_shared(); 17 | sample_queue_ = std::make_shared(); 18 | if (denoise) { 19 | if (sample_rate != 48000) { 20 | upsampler_ = std::make_shared(sample_rate, 48000); 21 | if (sample_rate == 8000) { 22 | downsampler_ = std::make_shared(48000, 8000); 23 | } else { 24 | downsampler_ = std::make_shared(48000, 16000); 25 | } 26 | } 27 | } else if (sample_rate != 16000 && sample_rate != 8000) {  // BUG FIX: was ||, a tautology that resampled even native 16k/8k input 28 | downsampler_ = std::make_shared(sample_rate, 16000); 29 | } 30 | 31 | Reset(); 32 | } 33 | 34 | void VadModel::Reset() { 35 | state_.resize(SIZE_STATE); 36 | std::memset(state_.data(), 0, SIZE_STATE * sizeof(float));  // memset's fill value is an int byte; 0.0f was silently converted to 0 37 | on_speech_ = false; 38 | temp_end_ = 0; 39 | current_sample_ = 0; 40 | sample_queue_->Clear(); 41 | denoiser_->Reset(); 42 | if (upsampler_) { 43 | upsampler_->Reset(); 44 | } 45 | if (downsampler_) { 46 | downsampler_->Reset(); 47 | } 48 | } 49 | 50 | float VadModel::Forward(const std::vector& pcm) { 51 | std::vector input_pcm{pcm.data(), pcm.data() + pcm.size()}; 52 | for (int i = 0; i < input_pcm.size(); i++) { 53 | input_pcm[i] /= 32768.0; 54 | } 55 | 56 | // batch_size * num_samples 57 | const int64_t batch_size = 1; 58 | int64_t input_node_dims[2] = {batch_size, 59 | static_cast(input_pcm.size())}; 60 | auto input_ort = Ort::Value::CreateTensor( 61 | memory_info_, input_pcm.data(), input_pcm.size(), input_node_dims, 2); 62 | 63 | const int64_t sr_node_dims[1] = {batch_size}; 64 | std::vector sr = {sample_rate_}; 65 | auto sr_ort = Ort::Value::CreateTensor(memory_info_, sr.data(), 66 | batch_size, sr_node_dims, 1); 67 | const int64_t state_node_dims[3] = {2, batch_size, 128}; 68 | auto state_ort =
Ort::Value::CreateTensor( 69 | memory_info_, state_.data(), SIZE_STATE, state_node_dims, 3); 70 | 71 | std::vector ort_inputs; 72 | ort_inputs.emplace_back(std::move(input_ort)); 73 | ort_inputs.emplace_back(std::move(state_ort)); 74 | ort_inputs.emplace_back(std::move(sr_ort)); 75 | 76 | auto ort_outputs = session_->Run( 77 | Ort::RunOptions{nullptr}, input_node_names_.data(), ort_inputs.data(), 78 | ort_inputs.size(), output_node_names_.data(), output_node_names_.size()); 79 | 80 | float posterier = ort_outputs[0].GetTensorMutableData()[0]; 81 | float* state = ort_outputs[1].GetTensorMutableData(); 82 | state_.assign(state, state + SIZE_STATE); 83 | 84 | return posterier; 85 | } 86 | 87 | void VadModel::AcceptWaveform(const std::vector& pcm) { 88 | std::vector in_pcm{pcm.data(), pcm.data() + pcm.size()}; 89 | std::vector resampled_pcm; 90 | if (denoise_) { 91 | std::vector denoised_pcm; 92 | // 0. Upsample to 48k for RnNoise 93 | if (upsampler_) { 94 | upsampler_->Resample(in_pcm, &resampled_pcm); 95 | in_pcm = resampled_pcm; 96 | } 97 | // 1. Denoise with RnNoise 98 | denoiser_->Denoise(in_pcm, &denoised_pcm); 99 | in_pcm = denoised_pcm; 100 | } 101 | // 2. Downsample to 16k for VAD 102 | if (downsampler_) { 103 | downsampler_->Resample(in_pcm, &resampled_pcm); 104 | sample_rate_ = 16000; 105 | in_pcm = resampled_pcm; 106 | } 107 | sample_queue_->AcceptWaveform(in_pcm); 108 | } 109 | 110 | void VadModel::Vad(float* speech_start, float* speech_end, bool return_relative, 111 | bool return_seconds) { 112 | std::vector in_pcm; 113 | int num_frames = sample_queue_->NumSamples() / frame_size_; 114 | if (num_frames > 0) { 115 | sample_queue_->Read(frame_size_, &in_pcm); 116 | int window_size_samples = in_pcm.size(); 117 | current_sample_ += window_size_samples; 118 | float speech_prob = Forward(in_pcm); 119 | 120 | // 1. 
start 121 | if (speech_prob >= threshold_) { 122 | temp_end_ = 0; 123 | if (on_speech_ == false) { 124 | on_speech_ = true; 125 | *speech_start = 126 | current_sample_ - window_size_samples - speech_pad_samples_; 127 | if (return_relative) { 128 | *speech_start = current_sample_ - *speech_start; 129 | } 130 | if (return_seconds) { 131 | *speech_start = round(*speech_start / sample_rate_ * 1000) / 1000; 132 | } 133 | } 134 | } 135 | // 2. stop 136 | if (speech_prob < (threshold_ - 0.15) && on_speech_ == true) { 137 | if (temp_end_ == 0) { 138 | temp_end_ = current_sample_; 139 | } 140 | // hangover 141 | if (current_sample_ - temp_end_ >= min_sil_dur_samples_) { 142 | *speech_end = temp_end_ + speech_pad_samples_; 143 | if (return_relative) { 144 | *speech_end = current_sample_ - *speech_end; 145 | } 146 | if (return_seconds) { 147 | *speech_end = round(*speech_end / sample_rate_ * 1000) / 1000; 148 | } 149 | temp_end_ = 0; 150 | on_speech_ = false; 151 | } 152 | } 153 | } 154 | } 155 | -------------------------------------------------------------------------------- /src/frontend/wav.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2016 Personal (Binbin Zhang) 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | #ifndef WAV_H_ 16 | #define WAV_H_ 17 | 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | 24 | #include 25 | 26 | #include "glog/logging.h" 27 | 28 | namespace wav { 29 | 30 | struct WavHeader { 31 | char riff[4] = {'R', 'I', 'F', 'F'}; 32 | unsigned int size = 0; 33 | char wav[4] = {'W', 'A', 'V', 'E'}; 34 | char fmt[4] = {'f', 'm', 't', ' '}; 35 | unsigned int fmt_size = 16; 36 | uint16_t format = 1; 37 | uint16_t channels = 0; 38 | unsigned int sample_rate = 0; 39 | unsigned int bytes_per_second = 0; 40 | uint16_t block_size = 0; 41 | uint16_t bit = 0; 42 | char data[4] = {'d', 'a', 't', 'a'}; 43 | unsigned int data_size = 0; 44 | 45 | WavHeader() {} 46 | 47 | WavHeader(int num_samples, int num_channels, int sample_rate, 48 | int bits_per_sample) { 49 | data_size = num_samples * num_channels * (bits_per_sample / 8); 50 | size = sizeof(WavHeader) - 8 + data_size; 51 | channels = num_channels; 52 | this->sample_rate = sample_rate; 53 | bytes_per_second = sample_rate * num_channels * (bits_per_sample / 8); 54 | block_size = num_channels * (bits_per_sample / 8); 55 | bit = bits_per_sample; 56 | } 57 | }; 58 | 59 | class WavReader { 60 | public: 61 | WavReader() : data_(nullptr) {} 62 | explicit WavReader(const std::string& filename) { Open(filename); } 63 | 64 | bool Open(const std::string& filename) { 65 | FILE* fp = fopen(filename.c_str(), "rb"); 66 | if (NULL == fp) { 67 | LOG(WARNING) << "Error in read " << filename; 68 | return false; 69 | } 70 | 71 | WavHeader header; 72 | fread(&header, 1, sizeof(header), fp); 73 | if (header.fmt_size < 16) { 74 | fprintf(stderr, 75 | "WaveData: expect PCM format data " 76 | "to have fmt chunk of at least size 16.\n"); 77 | return false; 78 | } else if (header.fmt_size > 16) { 79 | int offset = 44 - 8 + header.fmt_size - 16; 80 | fseek(fp, offset, SEEK_SET); 81 | fread(header.data, 8, sizeof(char), fp); 82 | } 83 | // check "RIFF" "WAVE" "fmt " "data" 84 | 85 | // Skip any sub-chunks between 
"fmt" and "data". Usually there will 86 | // be a single "fact" sub chunk, but on Windows there can also be a 87 | // "list" sub chunk. 88 | while (0 != strncmp(header.data, "data", 4)) { 89 | // We will just ignore the data in these chunks. 90 | fseek(fp, header.data_size, SEEK_CUR); 91 | // read next sub chunk 92 | fread(header.data, 8, sizeof(char), fp); 93 | } 94 | 95 | num_channels_ = header.channels; 96 | sample_rate_ = header.sample_rate; 97 | bits_per_sample_ = header.bit; 98 | int num_data = header.data_size / (bits_per_sample_ / 8); 99 | data_ = new float[num_data]; 100 | num_samples_ = num_data / num_channels_; 101 | 102 | for (int i = 0; i < num_data; ++i) { 103 | switch (bits_per_sample_) { 104 | case 8: { 105 | char sample; 106 | fread(&sample, 1, sizeof(char), fp); 107 | data_[i] = static_cast(sample); 108 | break; 109 | } 110 | case 16: { 111 | int16_t sample; 112 | fread(&sample, 1, sizeof(int16_t), fp); 113 | data_[i] = static_cast(sample); 114 | break; 115 | } 116 | case 32: { 117 | int sample; 118 | fread(&sample, 1, sizeof(int), fp); 119 | data_[i] = static_cast(sample); 120 | break; 121 | } 122 | default: 123 | fprintf(stderr, "unsupported quantization bits"); 124 | exit(1); 125 | } 126 | } 127 | fclose(fp); 128 | return true; 129 | } 130 | 131 | int num_channels() const { return num_channels_; } 132 | int sample_rate() const { return sample_rate_; } 133 | int bits_per_sample() const { return bits_per_sample_; } 134 | int num_samples() const { return num_samples_; } 135 | 136 | ~WavReader() { delete[] data_; } 137 | 138 | const float* data() const { return data_; } 139 | 140 | private: 141 | int num_channels_; 142 | int sample_rate_; 143 | int bits_per_sample_; 144 | int num_samples_; // sample points per channel 145 | float* data_; 146 | }; 147 | 148 | class WavWriter { 149 | public: 150 | WavWriter(const float* data, int num_samples, int num_channels, 151 | int sample_rate, int bits_per_sample) 152 | : data_(data), 153 | 
num_samples_(num_samples), 154 | num_channels_(num_channels), 155 | sample_rate_(sample_rate), 156 | bits_per_sample_(bits_per_sample) {} 157 | 158 | void Write(const std::string& filename) { 159 | FILE* fp = fopen(filename.c_str(), "wb"); 160 | WavHeader header(num_samples_, num_channels_, sample_rate_, 161 | bits_per_sample_); 162 | fwrite(&header, 1, sizeof(header), fp); 163 | 164 | for (int i = 0; i < num_samples_; ++i) { 165 | for (int j = 0; j < num_channels_; ++j) { 166 | switch (bits_per_sample_) { 167 | case 8: { 168 | char sample = static_cast(data_[i * num_channels_ + j]); 169 | fwrite(&sample, 1, sizeof(sample), fp); 170 | break; 171 | } 172 | case 16: { 173 | int16_t sample = static_cast(data_[i * num_channels_ + j]); 174 | fwrite(&sample, 1, sizeof(sample), fp); 175 | break; 176 | } 177 | case 32: { 178 | int sample = static_cast(data_[i * num_channels_ + j]); 179 | fwrite(&sample, 1, sizeof(sample), fp); 180 | break; 181 | } 182 | } 183 | } 184 | } 185 | fclose(fp); 186 | } 187 | 188 | private: 189 | const float* data_; 190 | int num_samples_; // total float points in data_ 191 | int num_channels_; 192 | int sample_rate_; 193 | int bits_per_sample_; 194 | }; 195 | 196 | class StreamWavWriter { 197 | public: 198 | StreamWavWriter(int num_channels, int sample_rate, int bits_per_sample) 199 | : num_channels_(num_channels), 200 | sample_rate_(sample_rate), 201 | bits_per_sample_(bits_per_sample), 202 | total_num_samples_(0) {} 203 | 204 | StreamWavWriter(const std::string& filename, int num_channels, 205 | int sample_rate, int bits_per_sample) 206 | : StreamWavWriter(num_channels, sample_rate, bits_per_sample) { 207 | Open(filename); 208 | } 209 | 210 | void Open(const std::string& filename) { 211 | fp_ = fopen(filename.c_str(), "wb"); 212 | fseek(fp_, sizeof(WavHeader), SEEK_SET); 213 | } 214 | 215 | void Write(const int16_t* sample_data, size_t num_samples) { 216 | fwrite(sample_data, sizeof(int16_t), num_samples, fp_); 217 | total_num_samples_ += 
num_samples; 218 | } 219 | 220 | void Close() { 221 | WavHeader header(total_num_samples_, num_channels_, sample_rate_, 222 | bits_per_sample_); 223 | fseek(fp_, 0L, SEEK_SET); 224 | fwrite(&header, 1, sizeof(header), fp_); 225 | fclose(fp_); 226 | } 227 | 228 | private: 229 | FILE* fp_; 230 | int num_channels_; 231 | int sample_rate_; 232 | int bits_per_sample_; 233 | size_t total_num_samples_; 234 | }; 235 | 236 | } // namespace wav 237 | 238 | #endif // WAV_H_ 239 | -------------------------------------------------------------------------------- /pysilero/pysilero.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Zhendong Peng (pzd17@tsinghua.org.cn) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import math 16 | from functools import partial 17 | from pathlib import Path 18 | from typing import Union 19 | 20 | import numpy as np 21 | from audiolab import info, load_audio, save_audio 22 | from frame_queue import FrameQueue 23 | from pickable_session import silero_vad 24 | from pyrnnoise import RNNoise 25 | from tqdm import tqdm 26 | from utils import get_energy 27 | 28 | 29 | class SileroVAD: 30 | def __init__( 31 | self, 32 | version: str = "v5", 33 | sample_rate: int = 16000, 34 | threshold: float = 0.5, 35 | min_silence_duration_ms: int = 300, 36 | speech_pad_ms: int = 100, 37 | denoise: bool = False, 38 | ): 39 | """ 40 | Init silero VAD model 41 | 42 | Parameters 43 | ---------- 44 | version: str (default - v5) 45 | silero-vad version (v4 or v5) 46 | sample_rate: int (default - 16000) 47 | sample rate of the input audio 48 | threshold: float (default - 0.5) 49 | Speech threshold. Silero VAD outputs speech probabilities for each audio 50 | chunk, probabilities ABOVE this value are considered as SPEECH. It is 51 | better to tune this parameter for each dataset separately, but "lazy" 52 | 0.5 is pretty good for most datasets. 53 | min_silence_duration_ms: int (default - 300 milliseconds) 54 | In the end of each speech chunk wait for min_silence_duration_ms before 55 | separating it. 56 | speech_pad_ms: int (default - 100 milliseconds) 57 | Final speech chunks are padded by speech_pad_ms each side. 58 | denoise: bool (default - False) 59 | whether denoise the audio samples. 
60 | """ 61 | self.version = version 62 | self.session = silero_vad[version] 63 | self.threshold = threshold 64 | self.sample_rate = sample_rate 65 | 66 | self.speech_pad_samples = speech_pad_ms * sample_rate // 1000 67 | self.min_silence_samples = min_silence_duration_ms * sample_rate // 1000 68 | self.model_sample_rate = sample_rate if sample_rate in [8000, 16000] else 16000 69 | 70 | if self.version == "v4": 71 | self.h = np.zeros((2, 1, 64), dtype=np.float32) 72 | self.c = np.zeros((2, 1, 64), dtype=np.float32) 73 | else: 74 | self.state = np.zeros((2, 1, 128), dtype=np.float32) 75 | self.context_size = 64 if self.model_sample_rate == 16000 else 32 76 | self.context = np.zeros((1, self.context_size), dtype=np.float32) 77 | 78 | self.num_samples = 512 if self.model_sample_rate == 16000 else 256 79 | self.queue = FrameQueue( 80 | self.num_samples, 81 | self.sample_rate, 82 | self.speech_pad_samples, 83 | out_rate=self.model_sample_rate, 84 | ) 85 | 86 | self.segment = 0 87 | self.denoiser = RNNoise(sample_rate) if denoise else None 88 | 89 | def reset(self): 90 | self.segment = 0 91 | self.queue = FrameQueue( 92 | self.num_samples, 93 | self.sample_rate, 94 | self.speech_pad_samples, 95 | out_rate=self.model_sample_rate, 96 | ) 97 | if self.version == "v4": 98 | self.h = np.zeros((2, 1, 64), dtype=np.float32) 99 | self.c = np.zeros((2, 1, 64), dtype=np.float32) 100 | else: 101 | self.state = np.zeros((2, 1, 128), dtype=np.float32) 102 | self.context = np.zeros((1, self.context_size), dtype=np.float32) 103 | 104 | def __call__(self, x, sr): 105 | if self.version == "v4": 106 | x = x[np.newaxis, :] 107 | sr = np.array(sr, dtype=np.int64) 108 | ort_inputs = {"input": x, "h": self.h, "c": self.c, "sr": sr} 109 | output, self.h, self.c = self.session.run(ort_inputs) 110 | else: 111 | x = np.concatenate((self.context, x[np.newaxis, :]), axis=1) 112 | self.context = x[:, -self.context_size :] 113 | ort_inputs = {"input": x, "state": self.state, "sr": np.array(sr, 
dtype=np.int64)} 114 | output, self.state = self.session.run(ort_inputs) 115 | return output 116 | 117 | @staticmethod 118 | def denoise_chunk(denoiser, chunk, is_last=False): 119 | frames = [] 120 | for _, frame in denoiser.process_chunk(chunk, is_last): 121 | frames.append(frame.squeeze()) 122 | return np.concatenate(frames) if len(frames) > 0 else np.array([]) 123 | 124 | def add_chunk(self, chunk, is_last=False): 125 | if self.denoiser is not None: 126 | chunk = self.denoise_chunk(self.denoiser, chunk, is_last) 127 | return self.queue.add_chunk(chunk, is_last) 128 | 129 | def get_speech_probs(self, wav_path: Union[str, Path]): 130 | """ 131 | Getting speech probabilities of audio frames (32ms/frame) 132 | 133 | Parameters 134 | ---------- 135 | wav_path: wav path 136 | 137 | Returns 138 | ---------- 139 | speech_probs: list of speech probabilities 140 | """ 141 | self.reset() 142 | audio, _ = load_audio(wav_path, dtype=np.float32, rate=self.sample_rate, to_mono=True) 143 | progress_bar = tqdm( 144 | total=math.ceil(info(wav_path).duration / 0.032), 145 | desc="VAD processing", 146 | unit="frames", 147 | bar_format="{l_bar}{bar}{r_bar} | {percentage:.2f}%", 148 | ) 149 | for _, _, frame in self.add_chunk(audio[0], True): 150 | progress_bar.update(1) 151 | yield np.around(self(frame, self.model_sample_rate)[0][0], 2) 152 | 153 | def process_segment(self, segment, wav, save_path, flat_layout, return_seconds): 154 | index = segment["segment"] 155 | start = max(segment["start"] - self.speech_pad_samples, 0) 156 | end = min(segment["end"] + self.speech_pad_samples, len(wav)) 157 | if save_path is not None: 158 | wav = wav[start:end] 159 | if self.denoiser is not None: 160 | # Initial denoiser for each segments 161 | denoiser = RNNoise(self.sample_rate) 162 | wav = self.denoise_chunk(denoiser, wav, True) 163 | if flat_layout: 164 | save_audio(str(save_path) + f"_{index:05d}.wav", wav[np.newaxis, :], self.sample_rate) 165 | else: 166 | save_path = Path(save_path) 167 | 
if not save_path.exists(): 168 | save_path.mkdir(parents=True, exist_ok=True) 169 | save_audio(str(save_path / f"{index:05d}.wav"), wav[np.newaxis, :], self.sample_rate) 170 | if return_seconds: 171 | start = round(start / self.sample_rate, 3) 172 | end = round(end / self.sample_rate, 3) 173 | return {"segment": index, "start": start, "end": end} 174 | 175 | def get_speech_timestamps( 176 | self, 177 | wav_path: Union[str, Path], 178 | save_path: Union[str, Path] = None, 179 | flat_layout: bool = True, 180 | min_speech_duration_ms: int = 250, 181 | max_speech_duration_s: float = float("inf"), 182 | return_seconds: bool = False, 183 | ): 184 | """ 185 | Splitting long audios into speech chunks using silero VAD 186 | 187 | Parameters 188 | ---------- 189 | wav_path: wav path 190 | save_path: string or Path (default - None) 191 | whether the save speech segments 192 | flat_layout: bool (default - True) 193 | whether use the flat directory structure 194 | min_speech_duration_ms: int (default - 250 milliseconds) 195 | Final speech chunks shorter min_speech_duration_ms are thrown out 196 | max_speech_duration_s: int (default - inf) 197 | Maximum duration of speech chunks in seconds 198 | Chunks longer than max_speech_duration_s will be split at the timestamp 199 | of the last silence that lasts more than 98ms (if any), to prevent 200 | agressive cutting. Otherwise, they will be split aggressively just 201 | before max_speech_duration_s. 
202 | return_seconds: bool (default - False) 203 | whether return timestamps in seconds (default - samples) 204 | 205 | Returns 206 | ---------- 207 | speeches: list of dicts 208 | list containing ends and beginnings of speech chunks (samples or seconds 209 | based on return_seconds) 210 | """ 211 | self.reset() 212 | audio, sample_rate = load_audio(wav_path, dtype=np.float32, rate=self.sample_rate, to_mono=True) 213 | progress_bar = tqdm( 214 | total=math.ceil(info(wav_path).duration / 0.032), 215 | desc="VAD processing", 216 | unit="frames", 217 | bar_format="{l_bar}{bar}{r_bar} | {percentage:.2f}%", 218 | ) 219 | 220 | min_silence_samples_at_max_speech = 98 * sample_rate // 1000 221 | min_speech_samples = min_speech_duration_ms * sample_rate // 1000 222 | max_speech_duration_samples = max_speech_duration_s * sample_rate 223 | max_speech_samples = max_speech_duration_samples - 2 * self.speech_pad_samples 224 | 225 | fn = partial( 226 | self.process_segment, 227 | wav=audio[0], 228 | save_path=save_path, 229 | flat_layout=flat_layout, 230 | return_seconds=return_seconds, 231 | ) 232 | 233 | current_speech = {} 234 | neg_threshold = self.threshold - 0.15 235 | triggered = False 236 | # to save potential segment end (and tolerate some silence) 237 | temp_end = 0 238 | # to save potential segment limits in case of maximum segment size reached 239 | prev_end = 0 240 | next_start = 0 241 | for frame_start, frame_end, frame in self.add_chunk(audio[0], True): 242 | progress_bar.update(1) 243 | speech_prob = self(frame, self.model_sample_rate) 244 | # current frame is speech 245 | if speech_prob >= self.threshold: 246 | if temp_end > 0 and next_start < prev_end: 247 | next_start = frame_end 248 | temp_end = 0 249 | if not triggered: 250 | triggered = True 251 | current_speech["start"] = frame_end 252 | continue 253 | # in speech, and speech duration is more than max speech duration 254 | if triggered and frame_start - current_speech["start"] > max_speech_samples: 255 | # 
prev_end larger than 0 means there is a short silence in the middle avoid aggressive cutting 256 | if prev_end > 0: 257 | current_speech["end"] = prev_end 258 | current_speech["segment"] = self.segment 259 | self.segment += 1 260 | yield fn(current_speech) 261 | current_speech = {} 262 | # previously reached silence (< neg_thres) and is still not speech (< thres) 263 | if next_start < prev_end: 264 | triggered = False 265 | else: 266 | current_speech["start"] = next_start 267 | prev_end = 0 268 | next_start = 0 269 | temp_end = 0 270 | else: 271 | current_speech["end"] = frame_end 272 | current_speech["segment"] = self.segment 273 | self.segment += 1 274 | yield fn(current_speech) 275 | current_speech = {} 276 | prev_end = 0 277 | next_start = 0 278 | temp_end = 0 279 | triggered = False 280 | continue 281 | # in speech, and current frame is silence 282 | if triggered and speech_prob < neg_threshold: 283 | if temp_end == 0: 284 | temp_end = frame_end 285 | # record the last silence before reaching max speech duration 286 | if frame_end - temp_end > min_silence_samples_at_max_speech: 287 | prev_end = temp_end 288 | if frame_end - temp_end >= self.min_silence_samples: 289 | current_speech["end"] = temp_end 290 | # keep the speech segment if it is longer than min_speech_samples 291 | if current_speech["end"] - current_speech["start"] > min_speech_samples: 292 | current_speech["segment"] = self.segment 293 | self.segment += 1 294 | yield fn(current_speech) 295 | 296 | current_speech = {} 297 | prev_end = 0 298 | next_start = 0 299 | temp_end = 0 300 | triggered = False 301 | 302 | # deal with the last speech segment 303 | if current_speech and len(audio) - current_speech["start"] > min_speech_samples: 304 | current_speech["end"] = len(audio) 305 | current_speech["segment"] = self.segment 306 | yield fn(current_speech) 307 | 308 | 309 | class VADIterator(SileroVAD): 310 | def __init__( 311 | self, 312 | version: str = "v5", 313 | sample_rate: int = 16000, 314 | 
threshold: float = 0.5, 315 | min_silence_duration_ms: int = 300, 316 | speech_pad_ms: int = 100, 317 | denoise: bool = False, 318 | ): 319 | """ 320 | Class for stream imitation 321 | """ 322 | super().__init__( 323 | version, 324 | sample_rate, 325 | threshold, 326 | min_silence_duration_ms, 327 | speech_pad_ms, 328 | denoise, 329 | ) 330 | self.segment = 0 331 | self.temp_end = 0 332 | self.triggered = False 333 | # for offline asr 334 | self.speech_samples = np.empty(0, dtype=np.float32) 335 | self.reset() 336 | 337 | def reset(self): 338 | super().reset() 339 | self.segment = 0 340 | self.temp_end = 0 341 | self.triggered = False 342 | self.speech_samples = np.empty(0, dtype=np.float32) 343 | 344 | def get_frame(self, speech_padding=False): 345 | frame = self.queue.get_frame(speech_padding) 346 | if speech_padding: 347 | self.speech_samples = np.empty(0, dtype=np.float32) 348 | self.speech_samples = np.concatenate((self.speech_samples, frame)) 349 | return frame 350 | 351 | def __call__(self, chunk, is_last=False, use_energy=False, return_seconds=False): 352 | """ 353 | chunk: audio chunk 354 | 355 | is_last: bool (default - False) 356 | whether is the last audio chunk 357 | use_energy: bool (default - False) 358 | whether to use harmonic energy to suppress background vocals 359 | return_seconds: bool (default - False) 360 | whether return timestamps in seconds (default - samples) 361 | """ 362 | for frame_start, frame_end, frame in self.add_chunk(chunk, is_last): 363 | speech_prob = super().__call__(frame, self.model_sample_rate) 364 | # Suppress background vocals by harmonic energy 365 | if use_energy: 366 | energy = get_energy(frame, self.model_sample_rate, from_harmonic=4) 367 | if speech_prob < 0.9 and energy < 500 * (1 - speech_prob): 368 | speech_prob = 0 369 | 370 | is_start = False 371 | if speech_prob >= self.threshold: 372 | self.temp_end = 0 373 | # triggered = True means the speech has been started 374 | if not self.triggered: 375 | is_start = 
True 376 | self.triggered = True 377 | speech_start = max(frame_start - self.speech_pad_samples, 0) 378 | if return_seconds: 379 | speech_start = round(speech_start / self.sample_rate, 3) 380 | yield {"start": speech_start}, self.get_frame(True) 381 | elif speech_prob < self.threshold - 0.15 and self.triggered: 382 | if not self.temp_end: 383 | self.temp_end = frame_end 384 | if frame_end - self.temp_end >= self.min_silence_samples: 385 | speech_end = self.temp_end + self.speech_pad_samples 386 | if return_seconds: 387 | speech_end = round(speech_end / self.sample_rate, 3) 388 | self.temp_end = 0 389 | self.triggered = False 390 | yield {"end": speech_end, "segment": self.segment}, self.get_frame() 391 | self.segment += 1 392 | if not is_start and self.triggered: 393 | yield {}, self.get_frame() 394 | 395 | if is_last and self.triggered: 396 | speech_end = self.queue.current_sample 397 | if return_seconds: 398 | speech_end = round(speech_end / self.sample_rate, 3) 399 | yield {"end": speech_end, "segment": self.segment}, self.get_frame() 400 | self.reset() 401 | --------------------------------------------------------------------------------