├── .github ├── FUNDING.yml └── workflows │ ├── lint.yml │ ├── pytest.yml │ ├── python-publish.yml │ └── quick-runs.yml ├── .gitignore ├── .readthedocs.yaml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── assets └── models │ ├── embedding_uint8.onnx │ └── segmentation_uint8.onnx ├── demo.gif ├── docs ├── Makefile ├── _static │ └── logo.png ├── conf.py ├── index.rst ├── make.bat └── requirements.txt ├── environment.yml ├── expected_outputs ├── offline │ ├── vbx+overlap_aware_segmentation │ │ ├── AMI.rttm │ │ ├── DIHARD2.rttm │ │ ├── DIHARD3.rttm │ │ └── VoxConverse.rttm │ └── vbx │ │ ├── AMI.rttm │ │ ├── DIHARD2.rttm │ │ ├── DIHARD3.rttm │ │ └── VoxConverse.rttm └── online │ ├── 0.5s │ ├── AMI.rttm │ ├── DIHARD3.rttm │ └── VoxConverse.rttm │ ├── 1.0s │ ├── AMI.rttm │ ├── DIHARD2_reoptimized.rttm │ ├── DIHARD3.rttm │ └── VoxConverse.rttm │ ├── 2.0s │ ├── AMI.rttm │ ├── DIHARD3.rttm │ └── VoxConverse.rttm │ ├── 3.0s │ ├── AMI.rttm │ ├── DIHARD3.rttm │ └── VoxConverse.rttm │ ├── 4.0s │ ├── AMI.rttm │ ├── DIHARD3.rttm │ └── VoxConverse.rttm │ ├── 5.0s+oracle_segmentation │ ├── AMI.rttm │ ├── DIHARD2.rttm │ ├── DIHARD3.rttm │ └── VoxConverse.rttm │ ├── 5.0s-overlap_aware_embeddings │ ├── AMI.rttm │ ├── DIHARD2.rttm │ ├── DIHARD3.rttm │ └── VoxConverse.rttm │ └── 5.0s │ ├── AMI.rttm │ ├── DIHARD2.rttm │ ├── DIHARD3.rttm │ └── VoxConverse.rttm ├── figure1.png ├── figure5.png ├── logo.jpg ├── paper.pdf ├── pipeline.gif ├── pyproject.toml ├── requirements.txt ├── setup.cfg ├── setup.py ├── src └── diart │ ├── __init__.py │ ├── argdoc.py │ ├── audio.py │ ├── blocks │ ├── __init__.py │ ├── aggregation.py │ ├── base.py │ ├── clustering.py │ ├── diarization.py │ ├── embedding.py │ ├── segmentation.py │ ├── utils.py │ └── vad.py │ ├── console │ ├── __init__.py │ ├── benchmark.py │ ├── client.py │ ├── serve.py │ ├── stream.py │ └── tune.py │ ├── features.py │ ├── functional.py │ ├── inference.py │ ├── mapping.py │ ├── models.py │ ├── operators.py │ ├── optim.py │ ├── progress.py │ ├── sinks.py │ ├── sources.py │ └── utils.py ├── table1.png └── tests ├── conftest.py ├── data ├── audio │ └── sample.wav └── rttm │ ├── latency_0.5.rttm │ ├── latency_1.rttm │ ├── latency_2.rttm │ ├── latency_3.rttm │ ├── latency_4.rttm │ └── latency_5.rttm ├── test_aggregation.py ├── test_diarization.py ├── test_end_to_end.py └── utils.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: [juanmc2005] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | # patreon: # Replace with a single Patreon username 5 | # open_collective: # Replace with a single Open Collective username 6 | # ko_fi: # Replace with a single Ko-fi username 7 | # tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | # community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | # liberapay: # Replace with a single Liberapay username 10 | # issuehunt: # Replace with a single IssueHunt username 11 | # otechie: # Replace with a single Otechie username 12 | # lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 13 | # custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 14 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 
3 | on: 4 | pull_request: 5 | branches: 6 | - main 7 | - develop 8 | 9 | jobs: 10 | lint: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v3 14 | - uses: psf/black@stable 15 | with: 16 | options: "--check --verbose" 17 | src: "./src/diart" 18 | version: "23.10.1" 19 | -------------------------------------------------------------------------------- /.github/workflows/pytest.yml: -------------------------------------------------------------------------------- 1 | name: Pytest 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - main 7 | - develop 8 | 9 | jobs: 10 | test: 11 | runs-on: ubuntu-latest 12 | 13 | strategy: 14 | matrix: 15 | python-version: ["3.10", "3.11", "3.12"] 16 | 17 | steps: 18 | - name: Checkout code 19 | uses: actions/checkout@v3 20 | 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v3 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | 26 | - name: Install apt dependencies 27 | run: | 28 | sudo apt-get update 29 | sudo apt-get -y install ffmpeg libportaudio2 30 | 31 | - name: Install pip dependencies 32 | run: | 33 | python -m pip install --upgrade pip 34 | pip install .[tests] 35 | 36 | - name: Run tests 37 | run: | 38 | pytest 39 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 
8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | deploy: 20 | 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v3 25 | - name: Set up Python 26 | uses: actions/setup-python@v3 27 | with: 28 | python-version: '3.x' 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install build 33 | - name: Build package 34 | run: python -m build 35 | - name: Publish package 36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 37 | with: 38 | user: __token__ 39 | password: ${{ secrets.PYPI_API_TOKEN }} 40 | -------------------------------------------------------------------------------- /.github/workflows/quick-runs.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: Quick Runs 5 | 6 | on: 7 | pull_request: 8 | branches: [ "main", "develop" ] 9 | 10 | jobs: 11 | build: 12 | 13 | runs-on: ubuntu-latest 14 | strategy: 15 | fail-fast: false 16 | matrix: 17 | python-version: ["3.10", "3.11", "3.12"] 18 | 19 | steps: 20 | - uses: actions/checkout@v3 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v3 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | - name: Download data 26 | run: | 27 | mkdir audio rttms trash 28 | wget --no-verbose --show-progress --continue -O audio/ES2002a_long.wav http://groups.inf.ed.ac.uk/ami/AMICorpusMirror/amicorpus/ES2002a/audio/ES2002a.Mix-Headset.wav 29 | wget --no-verbose --show-progress --continue -O audio/ES2002b_long.wav http://groups.inf.ed.ac.uk/ami/AMICorpusMirror/amicorpus/ES2002b/audio/ES2002b.Mix-Headset.wav 30 | wget --no-verbose --show-progress --continue -O rttms/ES2002a_long.rttm https://raw.githubusercontent.com/pyannote/AMI-diarization-setup/main/only_words/rttms/train/ES2002a.rttm 31 | wget --no-verbose --show-progress --continue -O rttms/ES2002b_long.rttm https://raw.githubusercontent.com/pyannote/AMI-diarization-setup/main/only_words/rttms/train/ES2002b.rttm 32 | - name: Install apt dependencies 33 | run: | 34 | sudo apt-get update 35 | sudo apt-get -y install ffmpeg libportaudio2 sox 36 | - name: Install pip dependencies 37 | run: | 38 | python -m pip install --upgrade pip 39 | pip install . 
40 | pip install onnxruntime==1.18.0 41 | - name: Crop audio and rttm 42 | run: | 43 | sox audio/ES2002a_long.wav audio/ES2002a.wav trim 00:40 00:30 44 | sox audio/ES2002b_long.wav audio/ES2002b.wav trim 00:10 00:30 45 | head -n 4 rttms/ES2002a_long.rttm > rttms/ES2002a.rttm 46 | head -n 7 rttms/ES2002b_long.rttm > rttms/ES2002b.rttm 47 | rm audio/ES2002a_long.wav 48 | rm audio/ES2002b_long.wav 49 | rm rttms/ES2002a_long.rttm 50 | rm rttms/ES2002b_long.rttm 51 | - name: Run stream 52 | run: | 53 | diart.stream audio/ES2002a.wav --segmentation assets/models/segmentation_uint8.onnx --embedding assets/models/embedding_uint8.onnx --output trash --no-plot 54 | - name: Run benchmark 55 | run: | 56 | diart.benchmark audio --reference rttms --batch-size 4 --segmentation assets/models/segmentation_uint8.onnx --embedding assets/models/embedding_uint8.onnx 57 | - name: Run tuning 58 | run: | 59 | diart.tune audio --reference rttms --batch-size 4 --num-iter 2 --output trash --segmentation assets/models/segmentation_uint8.onnx --embedding assets/models/embedding_uint8.onnx 60 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.toptal.com/developers/gitignore/api/python,pycharm+all 3 | # Edit at https://www.toptal.com/developers/gitignore?templates=python,pycharm+all 4 | 5 | ### PyCharm+all ### 6 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 7 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 8 | 9 | # User-specific stuff 10 | .idea/**/workspace.xml 11 | .idea/**/tasks.xml 12 | .idea/**/usage.statistics.xml 13 | .idea/**/dictionaries 14 | .idea/**/shelf 15 | 16 | # AWS User-specific 17 | .idea/**/aws.xml 18 | 19 | # Generated files 20 | .idea/**/contentModel.xml 21 | 22 | # Sensitive or high-churn files 23 | .idea/**/dataSources/ 24 | .idea/**/dataSources.ids 25 | .idea/**/dataSources.local.xml 26 | .idea/**/sqlDataSources.xml 27 | .idea/**/dynamic.xml 28 | .idea/**/uiDesigner.xml 29 | .idea/**/dbnavigator.xml 30 | 31 | # Gradle 32 | .idea/**/gradle.xml 33 | .idea/**/libraries 34 | 35 | # Gradle and Maven with auto-import 36 | # When using Gradle or Maven with auto-import, you should exclude module files, 37 | # since they will be recreated, and may cause churn. Uncomment if using 38 | # auto-import. 
39 | # .idea/artifacts 40 | # .idea/compiler.xml 41 | # .idea/jarRepositories.xml 42 | # .idea/modules.xml 43 | # .idea/*.iml 44 | # .idea/modules 45 | # *.iml 46 | # *.ipr 47 | 48 | # CMake 49 | cmake-build-*/ 50 | 51 | # Mongo Explorer plugin 52 | .idea/**/mongoSettings.xml 53 | 54 | # File-based project format 55 | *.iws 56 | 57 | # IntelliJ 58 | out/ 59 | 60 | # mpeltonen/sbt-idea plugin 61 | .idea_modules/ 62 | 63 | # JIRA plugin 64 | atlassian-ide-plugin.xml 65 | 66 | # Cursive Clojure plugin 67 | .idea/replstate.xml 68 | 69 | # Crashlytics plugin (for Android Studio and IntelliJ) 70 | com_crashlytics_export_strings.xml 71 | crashlytics.properties 72 | crashlytics-build.properties 73 | fabric.properties 74 | 75 | # Editor-based Rest Client 76 | .idea/httpRequests 77 | 78 | # Android studio 3.1+ serialized cache file 79 | .idea/caches/build_file_checksums.ser 80 | 81 | ### PyCharm+all Patch ### 82 | # Ignores the whole .idea folder and all .iml files 83 | # See https://github.com/joeblau/gitignore.io/issues/186 and https://github.com/joeblau/gitignore.io/issues/360 84 | 85 | .idea/ 86 | 87 | # Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-249601023 88 | 89 | *.iml 90 | modules.xml 91 | .idea/misc.xml 92 | *.ipr 93 | 94 | # Sonarlint plugin 95 | .idea/sonarlint 96 | 97 | ### Python ### 98 | # Byte-compiled / optimized / DLL files 99 | __pycache__/ 100 | *.py[cod] 101 | *$py.class 102 | 103 | # C extensions 104 | *.so 105 | 106 | # Distribution / packaging 107 | .Python 108 | build/ 109 | develop-eggs/ 110 | dist/ 111 | downloads/ 112 | eggs/ 113 | .eggs/ 114 | lib/ 115 | lib64/ 116 | parts/ 117 | sdist/ 118 | var/ 119 | wheels/ 120 | share/python-wheels/ 121 | *.egg-info/ 122 | .installed.cfg 123 | *.egg 124 | MANIFEST 125 | 126 | # PyInstaller 127 | # Usually these files are written by a python script from a template 128 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 129 | *.manifest 130 | *.spec 131 | 132 | # Installer logs 133 | pip-log.txt 134 | pip-delete-this-directory.txt 135 | 136 | # Unit test / coverage reports 137 | htmlcov/ 138 | .tox/ 139 | .nox/ 140 | .coverage 141 | .coverage.* 142 | .cache 143 | nosetests.xml 144 | coverage.xml 145 | *.cover 146 | *.py,cover 147 | .hypothesis/ 148 | .pytest_cache/ 149 | cover/ 150 | 151 | # Translations 152 | *.mo 153 | *.pot 154 | 155 | # Django stuff: 156 | *.log 157 | local_settings.py 158 | db.sqlite3 159 | db.sqlite3-journal 160 | 161 | # Flask stuff: 162 | instance/ 163 | .webassets-cache 164 | 165 | # Scrapy stuff: 166 | .scrapy 167 | 168 | # Sphinx documentation 169 | docs/_build/ 170 | 171 | # PyBuilder 172 | .pybuilder/ 173 | target/ 174 | 175 | # Jupyter Notebook 176 | .ipynb_checkpoints 177 | 178 | # IPython 179 | profile_default/ 180 | ipython_config.py 181 | 182 | # pyenv 183 | # For a library or package, you might want to ignore these files since the code is 184 | # intended to run in multiple environments; otherwise, check them in: 185 | # .python-version 186 | 187 | # pipenv 188 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 189 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 190 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 191 | # install all needed dependencies. 192 | #Pipfile.lock 193 | 194 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 195 | __pypackages__/ 196 | 197 | # Celery stuff 198 | celerybeat-schedule 199 | celerybeat.pid 200 | 201 | # SageMath parsed files 202 | *.sage.py 203 | 204 | # Environments 205 | .env 206 | .venv 207 | env/ 208 | venv/ 209 | ENV/ 210 | env.bak/ 211 | venv.bak/ 212 | 213 | # Spyder project settings 214 | .spyderproject 215 | .spyproject 216 | 217 | # Rope project settings 218 | .ropeproject 219 | 220 | # mkdocs documentation 221 | /site 222 | 223 | # mypy 224 | .mypy_cache/ 225 | .dmypy.json 226 | dmypy.json 227 | 228 | # Pyre type checker 229 | .pyre/ 230 | 231 | # pytype static type analyzer 232 | .pytype/ 233 | 234 | # Cython debug symbols 235 | cython_debug/ 236 | 237 | # End of https://www.toptal.com/developers/gitignore/api/python,pycharm+all 238 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: "ubuntu-22.04" 5 | tools: 6 | python: "3.10" 7 | 8 | python: 9 | install: 10 | - requirements: docs/requirements.txt 11 | # Install diart before building the docs 12 | - method: pip 13 | path: . 14 | 15 | sphinx: 16 | configuration: docs/conf.py -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to diart 2 | 3 | Thank you for considering contributing to diart! We appreciate your time and effort to help make this project better. 4 | 5 | ## Before You Start 6 | 7 | 1. **Search for Existing Issues or Discussions:** 8 | - Before opening a new issue or discussion, please check if there's already an existing one related to your topic. This helps avoid duplicates and keeps discussions centralized. 9 | 10 | 2. **Discuss Your Contribution:** 11 | - If you plan to make a significant change, it's advisable to discuss it in an issue first. This ensures that your contribution aligns with the project's goals and avoids duplicated efforts. 12 | 13 | 3. **Questions about diart:** 14 | - For general questions about diart, use the discussion space on GitHub. This helps in fostering a collaborative environment and encourages knowledge-sharing. 15 | 16 | ## Opening Issues 17 | 18 | If you encounter a problem with diart or want to suggest an improvement, please follow these guidelines when opening an issue: 19 | 20 | - **Bug Reports:** 21 | - Clearly describe the error, including any relevant stack traces. 22 | - Provide a minimal, reproducible example that demonstrates the issue. 23 | - Mention the version of diart you are using (as well as any dependencies related to the bug). 24 | 25 | - **Feature Requests:** 26 | - Clearly outline the new feature you are proposing. 27 | - Explain how it would benefit the project. 28 | 29 | ## Opening Pull Requests 30 | 31 | We welcome and appreciate contributions! To ensure a smooth review process, please follow these guidelines when opening a pull request: 32 | 33 | - **Create a Branch:** 34 | - Work on your changes in a dedicated branch created from `develop`. 35 | 36 | - **Commit Messages:** 37 | - Write clear and concise commit messages, explaining the purpose of each change. 38 | 39 | - **Documentation:** 40 | - Update documentation when introducing new features or making changes that impact existing functionality. 41 | 42 | - **Tests:** 43 | - If applicable, add or update tests to cover your changes. 
44 | 45 | - **Code Style:** 46 | - Follow the existing coding style of the project. We use `black` and `isort`. 47 | 48 | - **Discuss Before Major Changes:** 49 | - If your PR includes significant changes, discuss it in an issue first. 50 | 51 | - **Follow the existing workflow:** 52 | - Make sure to open your PR against `develop` (**not** `main`). 53 | 54 | ## Thank You 55 | 56 | Your contributions make diart better for everyone. Thank you for your time and dedication! 57 | 58 | Happy coding! 59 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Université Paris-Saclay 4 | Copyright (c) 2021 CNRS 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 |

4 | 5 |

6 | 🌿 Build AI-powered real-time audio applications in a breeze 🌿 7 |

8 | 9 |

 10 | Badges: PyPI Version | PyPI Downloads | Python Versions | Code size in bytes | License

17 | 18 |
19 |

 20 | 💾 Installation | 🎙️ Stream audio | 🧠 Models | 📈 Tuning | 🧠🔗 Pipelines | 🌐 WebSockets | 🔬 Research

48 |
49 | 50 |
51 | 52 |

53 | 54 |

55 | 56 | ## ⚡ Quick introduction 57 | 58 | Diart is a python framework to build AI-powered real-time audio applications. 59 | Its key feature is the ability to recognize different speakers in real time with state-of-the-art performance, 60 | a task commonly known as "speaker diarization". 61 | 62 | The pipeline `diart.SpeakerDiarization` combines a speaker segmentation and a speaker embedding model 63 | to power an incremental clustering algorithm that gets more accurate as the conversation progresses: 64 | 65 |

66 | 67 |

68 | 69 | With diart you can also create your own custom AI pipeline, benchmark it, 70 | tune its hyper-parameters, and even serve it on the web using websockets. 71 | 72 | **We provide pre-trained pipelines for:** 73 | 74 | - Speaker Diarization 75 | - Voice Activity Detection 76 | - Transcription ([coming soon](https://github.com/juanmc2005/diart/pull/144)) 77 | - [Speaker-Aware Transcription](https://betterprogramming.pub/color-your-captions-streamlining-live-transcriptions-with-diart-and-openais-whisper-6203350234ef) ([coming soon](https://github.com/juanmc2005/diart/pull/147)) 78 | 79 | ## 💾 Installation 80 | 81 | **1) Make sure your system has the following dependencies:** 82 | 83 | ``` 84 | ffmpeg < 4.4 85 | portaudio == 19.6.X 86 | libsndfile >= 1.2.2 87 | ``` 88 | 89 | Alternatively, we provide an `environment.yml` file for a pre-configured conda environment: 90 | 91 | ```shell 92 | conda env create -f diart/environment.yml 93 | conda activate diart 94 | ``` 95 | 96 | **2) Install the package:** 97 | ```shell 98 | pip install diart 99 | ``` 100 | 101 | ### Get access to 🎹 pyannote models 102 | 103 | By default, diart is based on [pyannote.audio](https://github.com/pyannote/pyannote-audio) models from the [huggingface](https://huggingface.co/) hub. 104 | In order to use them, please follow these steps: 105 | 106 | 1) [Accept user conditions](https://huggingface.co/pyannote/segmentation) for the `pyannote/segmentation` model 107 | 2) [Accept user conditions](https://huggingface.co/pyannote/segmentation-3.0) for the newest `pyannote/segmentation-3.0` model 108 | 3) [Accept user conditions](https://huggingface.co/pyannote/embedding) for the `pyannote/embedding` model 109 | 4) Install [huggingface-cli](https://huggingface.co/docs/huggingface_hub/quick-start#install-the-hub-library) and [log in](https://huggingface.co/docs/huggingface_hub/quick-start#login) with your user access token (or provide it manually in diart CLI or API). 110 | 111 | ## 🎙️ Stream audio 112 | 113 | ### From the command line 114 | 115 | A recorded conversation: 116 | 117 | ```shell 118 | diart.stream /path/to/audio.wav 119 | ``` 120 | 121 | A live conversation: 122 | 123 | ```shell 124 | # Use "microphone:ID" to select a non-default device 125 | # See `python -m sounddevice` for available devices 126 | diart.stream microphone 127 | ``` 128 | 129 | By default, diart runs a speaker diarization pipeline, equivalent to setting `--pipeline SpeakerDiarization`, 130 | but you can also set it to `--pipeline VoiceActivityDetection`. See `diart.stream -h` for more options. 131 | 132 | ### From python 133 | 134 | Use `StreamingInference` to run a pipeline on an audio source and write the results to disk: 135 | 136 | ```python 137 | from diart import SpeakerDiarization 138 | from diart.sources import MicrophoneAudioSource 139 | from diart.inference import StreamingInference 140 | from diart.sinks import RTTMWriter 141 | 142 | pipeline = SpeakerDiarization() 143 | mic = MicrophoneAudioSource() 144 | inference = StreamingInference(pipeline, mic, do_plot=True) 145 | inference.attach_observers(RTTMWriter(mic.uri, "/output/file.rttm")) 146 | prediction = inference() 147 | ``` 148 | 149 | For inference and evaluation on a dataset we recommend to use `Benchmark` (see notes on [reproducibility](#reproducibility)). 150 | 151 | ## 🧠 Models 152 | 153 | You can use other models with the `--segmentation` and `--embedding` arguments. 
154 | Or in python: 155 | 156 | ```python 157 | import diart.models as m 158 | 159 | segmentation = m.SegmentationModel.from_pretrained("model_name") 160 | embedding = m.EmbeddingModel.from_pretrained("model_name") 161 | ``` 162 | 163 | ### Pre-trained models 164 | 165 | Below is a list of all the models currently supported by diart: 166 | 167 | | Model Name | Model Type | CPU Time* | GPU Time* | 168 | |---------------------------------------------------------------------------------------------------------------------------|--------------|-----------|-----------| 169 | | [🤗](https://huggingface.co/pyannote/segmentation) `pyannote/segmentation` (default) | segmentation | 12ms | 8ms | 170 | | [🤗](https://huggingface.co/pyannote/segmentation-3.0) `pyannote/segmentation-3.0` | segmentation | 11ms | 8ms | 171 | | [🤗](https://huggingface.co/pyannote/embedding) `pyannote/embedding` (default) | embedding | 26ms | 12ms | 172 | | [🤗](https://huggingface.co/hbredin/wespeaker-voxceleb-resnet34-LM) `hbredin/wespeaker-voxceleb-resnet34-LM` (ONNX) | embedding | 48ms | 15ms | 173 | | [🤗](https://huggingface.co/pyannote/wespeaker-voxceleb-resnet34-LM) `pyannote/wespeaker-voxceleb-resnet34-LM` (PyTorch) | embedding | 150ms | 29ms | 174 | | [🤗](https://huggingface.co/speechbrain/spkrec-xvect-voxceleb) `speechbrain/spkrec-xvect-voxceleb` | embedding | 41ms | 15ms | 175 | | [🤗](https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb) `speechbrain/spkrec-ecapa-voxceleb` | embedding | 41ms | 14ms | 176 | | [🤗](https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb-mel-spec) `speechbrain/spkrec-ecapa-voxceleb-mel-spec` | embedding | 42ms | 14ms | 177 | | [🤗](https://huggingface.co/speechbrain/spkrec-resnet-voxceleb) `speechbrain/spkrec-resnet-voxceleb` | embedding | 41ms | 16ms | 178 | | [🤗](https://huggingface.co/nvidia/speakerverification_en_titanet_large) `nvidia/speakerverification_en_titanet_large` | embedding | 91ms | 16ms | 179 | 180 | The latency of segmentation models is measured in a VAD pipeline (5s chunks). 181 | 182 | The latency of embedding models is measured in a diarization pipeline using `pyannote/segmentation` (also 5s chunks). 
183 | 184 | \* CPU: AMD Ryzen 9 - GPU: RTX 4060 Max-Q 185 | 186 | ### Custom models 187 | 188 | Third-party models can be integrated by providing a loader function: 189 | 190 | ```python 191 | from diart import SpeakerDiarization, SpeakerDiarizationConfig 192 | from diart.models import EmbeddingModel, SegmentationModel 193 | 194 | def segmentation_loader(): 195 | # It should take a waveform and return a segmentation tensor 196 | return load_pretrained_model("my_model.ckpt") 197 | 198 | def embedding_loader(): 199 | # It should take (waveform, weights) and return per-speaker embeddings 200 | return load_pretrained_model("my_other_model.ckpt") 201 | 202 | segmentation = SegmentationModel(segmentation_loader) 203 | embedding = EmbeddingModel(embedding_loader) 204 | config = SpeakerDiarizationConfig( 205 | segmentation=segmentation, 206 | embedding=embedding, 207 | ) 208 | pipeline = SpeakerDiarization(config) 209 | ``` 210 | 211 | If you have an ONNX model, you can use `from_onnx()`: 212 | 213 | ```python 214 | from diart.models import EmbeddingModel 215 | 216 | embedding = EmbeddingModel.from_onnx( 217 | model_path="my_model.ckpt", 218 | input_names=["x", "w"], # defaults to ["waveform", "weights"] 219 | output_name="output", # defaults to "embedding" 220 | ) 221 | ``` 222 | 223 | ## 📈 Tune hyper-parameters 224 | 225 | Diart implements an optimizer based on [optuna](https://optuna.readthedocs.io/en/stable/index.html) that allows you to tune pipeline hyper-parameters to your needs. 226 | 227 | ### From the command line 228 | 229 | ```shell 230 | diart.tune /wav/dir --reference /rttm/dir --output /output/dir 231 | ``` 232 | 233 | See `diart.tune -h` for more options. 234 | 235 | ### From python 236 | 237 | ```python 238 | from diart.optim import Optimizer 239 | 240 | optimizer = Optimizer("/wav/dir", "/rttm/dir", "/output/dir") 241 | optimizer(num_iter=100) 242 | ``` 243 | 244 | This will write results to an sqlite database in `/output/dir`. 245 | 246 | ### Distributed tuning 247 | 248 | For bigger datasets, it is sometimes more convenient to run multiple optimization processes in parallel. 249 | To do this, create a study on a [recommended DBMS](https://optuna.readthedocs.io/en/stable/tutorial/10_key_features/004_distributed.html#sphx-glr-tutorial-10-key-features-004-distributed-py) (e.g. MySQL or PostgreSQL) making sure that the study and database names match: 250 | 251 | ```shell 252 | mysql -u root -e "CREATE DATABASE IF NOT EXISTS example" 253 | optuna create-study --study-name "example" --storage "mysql://root@localhost/example" 254 | ``` 255 | 256 | You can now run multiple identical optimizers pointing to this database: 257 | 258 | ```shell 259 | diart.tune /wav/dir --reference /rttm/dir --storage mysql://root@localhost/example 260 | ``` 261 | 262 | or in python: 263 | 264 | ```python 265 | from diart.optim import Optimizer 266 | from optuna.samplers import TPESampler 267 | import optuna 268 | 269 | db = "mysql://root@localhost/example" 270 | study = optuna.load_study("example", db, TPESampler()) 271 | optimizer = Optimizer("/wav/dir", "/rttm/dir", study) 272 | optimizer(num_iter=100) 273 | ``` 274 | 275 | ## 🧠🔗 Build pipelines 276 | 277 | For a more advanced usage, diart also provides building blocks that can be combined to create your own pipeline. 278 | Streaming is powered by [RxPY](https://github.com/ReactiveX/RxPY), but the `blocks` module is completely independent and can be used separately. 
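Because each block is a plain callable, you can also run it on a single chunk without any Rx machinery. The snippet below is a minimal sketch: the random chunk, the 16 kHz sample rate and the `(samples, channels)` shape are illustrative assumptions rather than requirements, and the two blocks are the same ones used in the example that follows. See `diart.features` for the input formats these blocks accept.

```python
import numpy as np
from diart.blocks import SpeakerSegmentation, OverlapAwareSpeakerEmbedding

segmentation = SpeakerSegmentation.from_pretrained("pyannote/segmentation")
embedding = OverlapAwareSpeakerEmbedding.from_pretrained("pyannote/embedding")

# Illustrative 5-second mono chunk at 16 kHz, shape (samples, channels).
# In practice this should be a real waveform at the model's sample rate.
chunk = np.random.randn(5 * 16000, 1).astype(np.float32)

seg = segmentation(chunk)    # per-frame, per-speaker activity scores
emb = embedding(chunk, seg)  # one embedding per local speaker
print(emb.shape)
```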
 279 | 
 280 | ### Example 
 281 | 
 282 | Obtain overlap-aware speaker embeddings from a microphone stream: 
 283 | 
 284 | ```python 
 285 | import rx.operators as ops 
 286 | import diart.operators as dops 
 287 | from diart.sources import MicrophoneAudioSource, FileAudioSource 
 288 | from diart.blocks import SpeakerSegmentation, OverlapAwareSpeakerEmbedding 
 289 | 
 290 | segmentation = SpeakerSegmentation.from_pretrained("pyannote/segmentation") 
 291 | embedding = OverlapAwareSpeakerEmbedding.from_pretrained("pyannote/embedding") 
 292 | 
 293 | source = MicrophoneAudioSource() 
 294 | # To take input from file: 
 295 | # source = FileAudioSource("/path/to/audio.wav", sample_rate=16000) 
 296 | 
 297 | # Make sure the models have been trained with this sample rate 
 298 | print(source.sample_rate) 
 299 | 
 300 | stream = source.stream.pipe( 
 301 |     # Reformat stream to 5s duration and 500ms shift 
 302 |     dops.rearrange_audio_stream(sample_rate=source.sample_rate), 
 303 |     ops.map(lambda wav: (wav, segmentation(wav))), 
 304 |     ops.starmap(embedding) 
 305 | ).subscribe(on_next=lambda emb: print(emb.shape)) 
 306 | 
 307 | source.read() 
 308 | ``` 
 309 | 
 310 | Output: 
 311 | 
 312 | ``` 
 313 | # Shape is (batch_size, num_speakers, embedding_dim) 
 314 | torch.Size([1, 3, 512]) 
 315 | torch.Size([1, 3, 512]) 
 316 | torch.Size([1, 3, 512]) 
 317 | ... 
 318 | ``` 
 319 | 
 320 | ## 🌐 WebSockets 
 321 | 
 322 | Diart is also compatible with the WebSocket protocol to serve pipelines on the web. 
 323 | 
 324 | ### From the command line 
 325 | 
 326 | ```shell 
 327 | diart.serve --host 0.0.0.0 --port 7007 
 328 | diart.client microphone --host <server-address> --port 7007 
 329 | ``` 
 330 | 
 331 | **Note:** make sure that the client uses the same `step` and `sample_rate` as the server, via `--step` and `-sr`. 
 332 | 
 333 | See `diart.serve -h` and `diart.client -h` for more options. 
 334 | 
 335 | ### From python 
 336 | 
 337 | For customized solutions, a server can also be created in Python using the `WebSocketAudioSource`: 
 338 | 
 339 | ```python 
 340 | from diart import SpeakerDiarization 
 341 | from diart.sources import WebSocketAudioSource 
 342 | from diart.inference import StreamingInference 
 343 | 
 344 | pipeline = SpeakerDiarization() 
 345 | source = WebSocketAudioSource(pipeline.config.sample_rate, "localhost", 7007) 
 346 | inference = StreamingInference(pipeline, source) 
 347 | inference.attach_hooks(lambda ann_wav: source.send(ann_wav[0].to_rttm())) 
 348 | prediction = inference() 
 349 | ``` 
 350 | 
 351 | ## 🔬 Powered by research 
 352 | 
 353 | Diart is the official implementation of the paper 
 354 | [Overlap-aware low-latency online speaker diarization based on end-to-end local segmentation](https://github.com/juanmc2005/diart/blob/main/paper.pdf) 
 355 | by [Juan Manuel Coria](https://juanmc2005.github.io/), 
 356 | [Hervé Bredin](https://herve.niderb.fr), 
 357 | [Sahar Ghannay](https://saharghannay.github.io/) 
 358 | and [Sophie Rosset](https://perso.limsi.fr/rosset/). 
 359 | 
 360 | 
 361 | > We propose to address online speaker diarization as a combination of incremental clustering and local diarization applied to a rolling buffer updated every 500ms. Every single step of the proposed pipeline is designed to take full advantage of the strong ability of a recently proposed end-to-end overlap-aware segmentation to detect and separate overlapping speakers. In particular, we propose a modified version of the statistics pooling layer (initially introduced in the x-vector architecture) to give less weight to frames where the segmentation model predicts simultaneous speakers.
Furthermore, we derive cannot-link constraints from the initial segmentation step to prevent two local speakers from being wrongfully merged during the incremental clustering step. Finally, we show how the latency of the proposed approach can be adjusted between 500ms and 5s to match the requirements of a particular use case, and we provide a systematic analysis of the influence of latency on the overall performance (on AMI, DIHARD and VoxConverse). 362 | 363 |
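For intuition, the pooling trick described above boils down to a weighted mean and standard deviation over frames, where frames with overlapped speech receive smaller weights. The sketch below is a generic illustration under that reading, not diart's actual code; in the library this role is played by the `OverlappedSpeechPenalty` and `OverlapAwareSpeakerEmbedding` blocks, with `gamma` and `beta` controlling the penalty.

```python
import torch

def weighted_stats_pooling(frames: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    """Pool frame-level features (num_frames, feat_dim) into one vector.

    weights: (num_frames,) values in [0, 1], lower on frames with overlapped speech.
    """
    w = weights.unsqueeze(-1)                            # (num_frames, 1)
    total = w.sum() + 1e-8
    mean = (w * frames).sum(dim=0) / total               # weighted mean
    var = (w * (frames - mean) ** 2).sum(dim=0) / total  # weighted variance
    return torch.cat([mean, var.clamp_min(1e-8).sqrt()])  # (2 * feat_dim,)
```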

 364 | ![Figure 1](https://github.com/juanmc2005/diart/blob/main/figure1.png?raw=true)

366 | 367 | ### Citation 368 | 369 | If you found diart useful, please make sure to cite our paper: 370 | 371 | ```bibtex 372 | @inproceedings{diart, 373 | author={Coria, Juan M. and Bredin, Hervé and Ghannay, Sahar and Rosset, Sophie}, 374 | booktitle={2021 IEEE Automatic Speech Recognition and Understanding Workshop (ASRU)}, 375 | title={Overlap-Aware Low-Latency Online Speaker Diarization Based on End-to-End Local Segmentation}, 376 | year={2021}, 377 | pages={1139-1146}, 378 | doi={10.1109/ASRU51503.2021.9688044}, 379 | } 380 | ``` 381 | 382 | ### Reproducibility 383 | 384 | ![Results table](https://github.com/juanmc2005/diart/blob/main/table1.png?raw=true) 385 | 386 | **Important:** We highly recommend installing `pyannote.audio<3.1` to reproduce these results. 387 | For more information, see [this issue](https://github.com/juanmc2005/diart/issues/214). 388 | 389 | Diart aims to be lightweight and capable of real-time streaming in practical scenarios. 390 | Its performance is very close to what is reported in the paper (and sometimes even a bit better). 391 | 392 | To obtain the best results, make sure to use the following hyper-parameters: 393 | 394 | | Dataset | latency | tau | rho | delta | 395 | |-------------|---------|--------|--------|-------| 396 | | DIHARD III | any | 0.555 | 0.422 | 1.517 | 397 | | AMI | any | 0.507 | 0.006 | 1.057 | 398 | | VoxConverse | any | 0.576 | 0.915 | 0.648 | 399 | | DIHARD II | 1s | 0.619 | 0.326 | 0.997 | 400 | | DIHARD II | 5s | 0.555 | 0.422 | 1.517 | 401 | 402 | `diart.benchmark` and `diart.inference.Benchmark` can run, evaluate and measure the real-time latency of the pipeline. For instance, for a DIHARD III configuration: 403 | 404 | ```shell 405 | diart.benchmark /wav/dir --reference /rttm/dir --tau-active=0.555 --rho-update=0.422 --delta-new=1.517 --segmentation pyannote/segmentation@Interspeech2021 406 | ``` 407 | 408 | or using the inference API: 409 | 410 | ```python 411 | from diart.inference import Benchmark, Parallelize 412 | from diart import SpeakerDiarization, SpeakerDiarizationConfig 413 | from diart.models import SegmentationModel 414 | 415 | benchmark = Benchmark("/wav/dir", "/rttm/dir") 416 | 417 | model_name = "pyannote/segmentation@Interspeech2021" 418 | model = SegmentationModel.from_pretrained(model_name) 419 | config = SpeakerDiarizationConfig( 420 | # Set the segmentation model used in the paper 421 | segmentation=model, 422 | step=0.5, 423 | latency=0.5, 424 | tau_active=0.555, 425 | rho_update=0.422, 426 | delta_new=1.517 427 | ) 428 | benchmark(SpeakerDiarization, config) 429 | 430 | # Run the same benchmark in parallel 431 | p_benchmark = Parallelize(benchmark, num_workers=4) 432 | if __name__ == "__main__": # Needed for multiprocessing 433 | p_benchmark(SpeakerDiarization, config) 434 | ``` 435 | 436 | This pre-calculates model outputs in batches, so it runs a lot faster. 437 | See `diart.benchmark -h` for more options. 438 | 439 | For convenience and to facilitate future comparisons, we also provide the 440 | expected outputs 441 | of the paper implementation in RTTM format for every entry of Table 1 and Figure 5. 442 | This includes the VBx offline topline as well as our proposed online approach with 443 | latencies 500ms, 1s, 2s, 3s, 4s, and 5s. 
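To check how close your own runs are to these expected outputs, one option is to score them with `pyannote.metrics`, which is already a diart dependency. This is a minimal sketch: the hypothesis path is a hypothetical location of your own `diart.benchmark` output, and the reference is one of the provided expected RTTMs.

```python
from pyannote.database.util import load_rttm
from pyannote.metrics.diarization import DiarizationErrorRate

metric = DiarizationErrorRate()

# Provided expected outputs (1s latency, AMI) vs. a hypothetical local run
expected = load_rttm("expected_outputs/online/1.0s/AMI.rttm")
predicted = load_rttm("/output/dir/AMI.rttm")

for uri, reference in expected.items():
    if uri in predicted:
        metric(reference, predicted[uri])

print(f"Gap to the expected outputs (DER): {100 * abs(metric):.2f}%")
```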
444 | 445 | ![Figure 5](https://github.com/juanmc2005/diart/blob/main/figure5.png?raw=true) 446 | 447 | ## 📑 License 448 | 449 | ``` 450 | MIT License 451 | 452 | Copyright (c) 2021 Université Paris-Saclay 453 | Copyright (c) 2021 CNRS 454 | 455 | Permission is hereby granted, free of charge, to any person obtaining a copy 456 | of this software and associated documentation files (the "Software"), to deal 457 | in the Software without restriction, including without limitation the rights 458 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 459 | copies of the Software, and to permit persons to whom the Software is 460 | furnished to do so, subject to the following conditions: 461 | 462 | The above copyright notice and this permission notice shall be included in all 463 | copies or substantial portions of the Software. 464 | 465 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 466 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 467 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 468 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 469 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 470 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 471 | SOFTWARE. 472 | ``` 473 | 474 |

Logo generated by DesignEvo free logo designer

475 | -------------------------------------------------------------------------------- /assets/models/embedding_uint8.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juanmc2005/diart/392d53a1b0cd67701ecc20b683bb10614df2f7fc/assets/models/embedding_uint8.onnx -------------------------------------------------------------------------------- /assets/models/segmentation_uint8.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juanmc2005/diart/392d53a1b0cd67701ecc20b683bb10614df2f7fc/assets/models/segmentation_uint8.onnx -------------------------------------------------------------------------------- /demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juanmc2005/diart/392d53a1b0cd67701ecc20b683bb10614df2f7fc/demo.gif -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_static/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juanmc2005/diart/392d53a1b0cd67701ecc20b683bb10614df2f7fc/docs/_static/logo.png -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 
2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | # -- Project information ----------------------------------------------------- 7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 8 | 9 | project = "diart" 10 | copyright = "2023, Juan Manuel Coria" 11 | author = "Juan Manuel Coria" 12 | release = "v0.9" 13 | 14 | # -- General configuration --------------------------------------------------- 15 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 16 | 17 | extensions = [ 18 | "autoapi.extension", 19 | "sphinx.ext.coverage", 20 | "sphinx.ext.napoleon", 21 | "sphinx_mdinclude", 22 | ] 23 | 24 | autoapi_dirs = ["../src/diart"] 25 | autoapi_options = [ 26 | "members", 27 | "undoc-members", 28 | "show-inheritance", 29 | "show-module-summary", 30 | "special-members", 31 | "imported-members", 32 | ] 33 | 34 | templates_path = ["_templates"] 35 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 36 | 37 | # -- Options for autodoc ---------------------------------------------------- 38 | # https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#configuration 39 | 40 | # Automatically extract typehints when specified and place them in 41 | # descriptions of the relevant function/method. 42 | autodoc_typehints = "description" 43 | 44 | # Don't show class signature with the class' name. 45 | autodoc_class_signature = "separated" 46 | 47 | # -- Options for HTML output ------------------------------------------------- 48 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 49 | 50 | html_theme = "furo" 51 | html_static_path = ["_static"] 52 | html_logo = "_static/logo.png" 53 | html_title = "diart documentation" 54 | 55 | 56 | def skip_submodules(app, what, name, obj, skip, options): 57 | return ( 58 | name.endswith("__init__") 59 | or name.startswith("diart.console") 60 | or name.startswith("diart.argdoc") 61 | ) 62 | 63 | 64 | def setup(sphinx): 65 | sphinx.connect("autoapi-skip-member", skip_submodules) 66 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | Get started with diart 2 | ====================== 3 | 4 | .. mdinclude:: ../README.md 5 | 6 | 7 | Useful Links 8 | ============ 9 | 10 | .. toctree:: 11 | :maxdepth: 1 12 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 
21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx==6.2.1 2 | sphinx-autoapi==3.0.0 3 | sphinx-mdinclude==0.5.3 4 | furo==2023.9.10 -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: diart 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - python=3.10 7 | - portaudio=19.6.* 8 | - pysoundfile=0.12.* 9 | - ffmpeg[version='<4.4'] 10 | - pip 11 | - pip: 12 | - . -------------------------------------------------------------------------------- /figure1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juanmc2005/diart/392d53a1b0cd67701ecc20b683bb10614df2f7fc/figure1.png -------------------------------------------------------------------------------- /figure5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juanmc2005/diart/392d53a1b0cd67701ecc20b683bb10614df2f7fc/figure5.png -------------------------------------------------------------------------------- /logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juanmc2005/diart/392d53a1b0cd67701ecc20b683bb10614df2f7fc/logo.jpg -------------------------------------------------------------------------------- /paper.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juanmc2005/diart/392d53a1b0cd67701ecc20b683bb10614df2f7fc/paper.pdf -------------------------------------------------------------------------------- /pipeline.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juanmc2005/diart/392d53a1b0cd67701ecc20b683bb10614df2f7fc/pipeline.gif -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools >= 40.9.0", 4 | "wheel", 5 | ] 6 | build-backend = "setuptools.build_meta" 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.20.2,<2.0.0 2 | matplotlib>=3.3.3,<4.0.0 3 | rx>=3.2.0 4 | scipy>=1.6.0 5 | sounddevice>=0.4.2 6 | einops>=0.3.0 7 | tqdm>=4.64.0 8 | pandas>=1.4.2 9 | torch>=1.12.1 10 | torchvision>=0.14.0 11 | torchaudio>=2.0.2 12 | pyannote.audio>=2.1.1 13 | requests>=2.31.0 14 | pyannote.core>=4.5 15 | pyannote.database>=4.1.1 16 | pyannote.metrics>=3.2 17 | optuna>=2.10 18 | websocket-server>=0.6.4 19 | websocket-client>=0.58.0 20 | rich>=12.5.1 21 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | 
name=diart 3 | version=0.9.2 4 | author=Juan Manuel Coria 5 | description=A python framework to build AI for real-time speech 6 | long_description=file: README.md 7 | long_description_content_type=text/markdown 8 | keywords=speaker diarization, streaming, online, real time, rxpy 9 | url=https://github.com/juanmc2005/diart 10 | license=MIT 11 | classifiers= 12 | Development Status :: 4 - Beta 13 | License :: OSI Approved :: MIT License 14 | Topic :: Multimedia :: Sound/Audio :: Analysis 15 | Topic :: Multimedia :: Sound/Audio :: Speech 16 | Topic :: Scientific/Engineering :: Artificial Intelligence 17 | 18 | [options] 19 | package_dir= 20 | =src 21 | packages=find: 22 | install_requires= 23 | numpy>=1.20.2,<2.0.0 24 | matplotlib>=3.3.3,<4.0.0 25 | rx>=3.2.0 26 | scipy>=1.6.0 27 | sounddevice>=0.4.2 28 | einops>=0.3.0 29 | tqdm>=4.64.0 30 | pandas>=1.4.2 31 | torch>=1.12.1 32 | torchvision>=0.14.0 33 | torchaudio>=2.0.2 34 | pyannote.audio>=2.1.1 35 | requests>=2.31.0 36 | pyannote.core>=4.5 37 | pyannote.database>=4.1.1 38 | pyannote.metrics>=3.2 39 | optuna>=2.10 40 | websocket-server>=0.6.4 41 | websocket-client>=0.58.0 42 | rich>=12.5.1 43 | 44 | [options.extras_require] 45 | tests= 46 | pytest>=7.4.0,<8.0.0 47 | onnxruntime==1.18.0 48 | 49 | [options.packages.find] 50 | where=src 51 | 52 | [options.entry_points] 53 | console_scripts= 54 | diart.stream=diart.console.stream:run 55 | diart.benchmark=diart.console.benchmark:run 56 | diart.tune=diart.console.tune:run 57 | diart.serve=diart.console.serve:run 58 | diart.client=diart.console.client:run 59 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | setuptools.setup() 3 | -------------------------------------------------------------------------------- /src/diart/__init__.py: -------------------------------------------------------------------------------- 1 | from .blocks import ( 2 | SpeakerDiarization, 3 | Pipeline, 4 | SpeakerDiarizationConfig, 5 | PipelineConfig, 6 | VoiceActivityDetection, 7 | VoiceActivityDetectionConfig, 8 | ) 9 | -------------------------------------------------------------------------------- /src/diart/argdoc.py: -------------------------------------------------------------------------------- 1 | SEGMENTATION = "Segmentation model name from pyannote" 2 | EMBEDDING = "Embedding model name from pyannote" 3 | DURATION = "Chunk duration (in seconds)" 4 | STEP = "Sliding window step (in seconds)" 5 | LATENCY = "System latency (in seconds). STEP <= LATENCY <= CHUNK_DURATION" 6 | TAU = "Probability threshold to consider a speaker as active. 0 <= TAU <= 1" 7 | RHO = "Speech ratio threshold to decide if centroids are updated with a given speaker. 0 <= RHO <= 1" 8 | DELTA = "Embedding-to-centroid distance threshold to flag a speaker as known or new. 0 <= DELTA <= 2" 9 | GAMMA = "Parameter gamma for overlapped speech penalty" 10 | BETA = "Parameter beta for overlapped speech penalty" 11 | MAX_SPEAKERS = "Maximum number of speakers" 12 | CPU = "Force models to run on CPU" 13 | BATCH_SIZE = "For segmentation and embedding pre-calculation. If BATCH_SIZE < 2, run fully online and estimate real-time latency" 14 | NUM_WORKERS = "Number of parallel workers" 15 | OUTPUT = "Directory to store the system's output in RTTM format" 16 | HF_TOKEN = "Huggingface authentication token for hosted models ('true' | 'false' | ). 
If 'true', it will use the token from huggingface-cli login" 17 | SAMPLE_RATE = "Sample rate of the audio stream" 18 | NORMALIZE_EMBEDDING_WEIGHTS = "Rescale embedding weights (min-max normalization) to be in the range [0, 1]. This is useful in some models without weighted statistics pooling that rely on masking, like Nvidia's NeMo or ECAPA-TDNN" 19 | -------------------------------------------------------------------------------- /src/diart/audio.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Text, Union 3 | 4 | import torch 5 | import torchaudio 6 | from torchaudio.functional import resample 7 | 8 | torchaudio.set_audio_backend("soundfile") 9 | 10 | 11 | FilePath = Union[Text, Path] 12 | 13 | 14 | class AudioLoader: 15 | def __init__(self, sample_rate: int, mono: bool = True): 16 | self.sample_rate = sample_rate 17 | self.mono = mono 18 | 19 | def load(self, filepath: FilePath) -> torch.Tensor: 20 | """Load an audio file into a torch.Tensor. 21 | 22 | Parameters 23 | ---------- 24 | filepath : FilePath 25 | Path to an audio file 26 | 27 | Returns 28 | ------- 29 | waveform : torch.Tensor, shape (channels, samples) 30 | """ 31 | waveform, sample_rate = torchaudio.load(filepath) 32 | # Get channel mean if mono 33 | if self.mono and waveform.shape[0] > 1: 34 | waveform = waveform.mean(dim=0, keepdim=True) 35 | # Resample if needed 36 | if self.sample_rate != sample_rate: 37 | waveform = resample(waveform, sample_rate, self.sample_rate) 38 | return waveform 39 | 40 | @staticmethod 41 | def get_duration(filepath: FilePath) -> float: 42 | """Get audio file duration in seconds. 43 | 44 | Parameters 45 | ---------- 46 | filepath : FilePath 47 | Path to an audio file. 48 | 49 | Returns 50 | ------- 51 | duration : float 52 | Duration in seconds. 53 | """ 54 | info = torchaudio.info(filepath) 55 | return info.num_frames / info.sample_rate 56 | -------------------------------------------------------------------------------- /src/diart/blocks/__init__.py: -------------------------------------------------------------------------------- 1 | from .aggregation import ( 2 | AggregationStrategy, 3 | HammingWeightedAverageStrategy, 4 | AverageStrategy, 5 | FirstOnlyStrategy, 6 | DelayedAggregation, 7 | ) 8 | from .clustering import OnlineSpeakerClustering 9 | from .embedding import ( 10 | SpeakerEmbedding, 11 | OverlappedSpeechPenalty, 12 | EmbeddingNormalization, 13 | OverlapAwareSpeakerEmbedding, 14 | ) 15 | from .segmentation import SpeakerSegmentation 16 | from .diarization import SpeakerDiarization, SpeakerDiarizationConfig 17 | from .base import PipelineConfig, Pipeline 18 | from .utils import Binarize, Resample, AdjustVolume 19 | from .vad import VoiceActivityDetection, VoiceActivityDetectionConfig 20 | -------------------------------------------------------------------------------- /src/diart/blocks/aggregation.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Optional, List 3 | 4 | import numpy as np 5 | from pyannote.core import Segment, SlidingWindow, SlidingWindowFeature 6 | from typing_extensions import Literal 7 | 8 | 9 | class AggregationStrategy(ABC): 10 | """Abstract class representing a strategy to aggregate overlapping buffers 11 | 12 | Parameters 13 | ---------- 14 | cropping_mode: ("strict", "loose", "center"), optional 15 | Defines the mode to crop buffer chunks as in pyannote.core. 
16 | See https://pyannote.github.io/pyannote-core/reference.html#pyannote.core.SlidingWindowFeature.crop 17 | Defaults to "loose". 18 | """ 19 | 20 | def __init__(self, cropping_mode: Literal["strict", "loose", "center"] = "loose"): 21 | assert cropping_mode in [ 22 | "strict", 23 | "loose", 24 | "center", 25 | ], f"Invalid cropping mode `{cropping_mode}`" 26 | self.cropping_mode = cropping_mode 27 | 28 | @staticmethod 29 | def build( 30 | name: Literal["mean", "hamming", "first"], 31 | cropping_mode: Literal["strict", "loose", "center"] = "loose", 32 | ) -> "AggregationStrategy": 33 | """Build an AggregationStrategy instance based on its name""" 34 | assert name in ("mean", "hamming", "first") 35 | if name == "mean": 36 | return AverageStrategy(cropping_mode) 37 | elif name == "hamming": 38 | return HammingWeightedAverageStrategy(cropping_mode) 39 | else: 40 | return FirstOnlyStrategy(cropping_mode) 41 | 42 | def __call__( 43 | self, buffers: List[SlidingWindowFeature], focus: Segment 44 | ) -> SlidingWindowFeature: 45 | """Aggregate chunks over a specific region. 46 | 47 | Parameters 48 | ---------- 49 | buffers: list of SlidingWindowFeature, shapes (frames, speakers) 50 | Buffers to aggregate 51 | focus: Segment 52 | Region to aggregate that is shared among the buffers 53 | 54 | Returns 55 | ------- 56 | aggregation: SlidingWindowFeature, shape (cropped_frames, speakers) 57 | Aggregated values over the focus region 58 | """ 59 | aggregation = self.aggregate(buffers, focus) 60 | resolution = focus.duration / aggregation.shape[0] 61 | resolution = SlidingWindow( 62 | start=focus.start, duration=resolution, step=resolution 63 | ) 64 | return SlidingWindowFeature(aggregation, resolution) 65 | 66 | @abstractmethod 67 | def aggregate( 68 | self, buffers: List[SlidingWindowFeature], focus: Segment 69 | ) -> np.ndarray: 70 | pass 71 | 72 | 73 | class HammingWeightedAverageStrategy(AggregationStrategy): 74 | """Compute the average weighted by the corresponding Hamming-window aligned to each buffer""" 75 | 76 | def aggregate( 77 | self, buffers: List[SlidingWindowFeature], focus: Segment 78 | ) -> np.ndarray: 79 | num_frames, num_speakers = buffers[0].data.shape 80 | hamming, intersection = [], [] 81 | for buffer in buffers: 82 | # Crop buffer to focus region 83 | b = buffer.crop(focus, mode=self.cropping_mode, fixed=focus.duration) 84 | # Crop Hamming window to focus region 85 | h = np.expand_dims(np.hamming(num_frames), axis=-1) 86 | h = SlidingWindowFeature(h, buffer.sliding_window) 87 | h = h.crop(focus, mode=self.cropping_mode, fixed=focus.duration) 88 | hamming.append(h.data) 89 | intersection.append(b.data) 90 | hamming, intersection = np.stack(hamming), np.stack(intersection) 91 | # Calculate weighted mean 92 | return np.sum(hamming * intersection, axis=0) / np.sum(hamming, axis=0) 93 | 94 | 95 | class AverageStrategy(AggregationStrategy): 96 | """Compute a simple average over the focus region""" 97 | 98 | def aggregate( 99 | self, buffers: List[SlidingWindowFeature], focus: Segment 100 | ) -> np.ndarray: 101 | # Stack all overlapping regions 102 | intersection = np.stack( 103 | [ 104 | buffer.crop(focus, mode=self.cropping_mode, fixed=focus.duration) 105 | for buffer in buffers 106 | ] 107 | ) 108 | return np.mean(intersection, axis=0) 109 | 110 | 111 | class FirstOnlyStrategy(AggregationStrategy): 112 | """Instead of aggregating, keep the first focus region in the buffer list""" 113 | 114 | def aggregate( 115 | self, buffers: List[SlidingWindowFeature], focus: Segment 116 | ) -> 
np.ndarray: 117 | return buffers[0].crop(focus, mode=self.cropping_mode, fixed=focus.duration) 118 | 119 | 120 | class DelayedAggregation: 121 | """Aggregate aligned overlapping windows of the same duration 122 | across sliding buffers with a specific step and latency. 123 | 124 | Parameters 125 | ---------- 126 | step: float 127 | Shift between two consecutive buffers, in seconds. 128 | latency: float, optional 129 | Desired latency, in seconds. Defaults to step. 130 | The higher the latency, the more overlapping windows to aggregate. 131 | strategy: ("mean", "hamming", "first"), optional 132 | Specifies how to aggregate overlapping windows. Defaults to "hamming". 133 | "mean": simple average 134 | "hamming": average weighted by the Hamming window values (aligned to the buffer) 135 | "first": no aggregation, pick the first overlapping window 136 | cropping_mode: ("strict", "loose", "center"), optional 137 | Defines the mode to crop buffer chunks as in pyannote.core. 138 | See https://pyannote.github.io/pyannote-core/reference.html#pyannote.core.SlidingWindowFeature.crop 139 | Defaults to "loose". 140 | 141 | Example 142 | -------- 143 | >>> duration = 5 144 | >>> frames = 500 145 | >>> step = 0.5 146 | >>> speakers = 2 147 | >>> start_time = 10 148 | >>> resolution = duration / frames 149 | >>> dagg = DelayedAggregation(step=step, latency=2, strategy="mean") 150 | >>> buffers = [ 151 | >>> SlidingWindowFeature( 152 | >>> np.random.rand(frames, speakers), 153 | >>> SlidingWindow(start=(i + start_time) * step, duration=resolution, step=resolution) 154 | >>> ) 155 | >>> for i in range(dagg.num_overlapping_windows) 156 | >>> ] 157 | >>> dagg.num_overlapping_windows 158 | ... 4 159 | >>> dagg(buffers).data.shape 160 | ... (51, 2) # Rounding errors are possible when cropping the buffers 161 | """ 162 | 163 | def __init__( 164 | self, 165 | step: float, 166 | latency: Optional[float] = None, 167 | strategy: Literal["mean", "hamming", "first"] = "hamming", 168 | cropping_mode: Literal["strict", "loose", "center"] = "loose", 169 | ): 170 | self.step = step 171 | self.latency = latency 172 | self.strategy = strategy 173 | assert cropping_mode in [ 174 | "strict", 175 | "loose", 176 | "center", 177 | ], f"Invalid cropping mode `{cropping_mode}`" 178 | self.cropping_mode = cropping_mode 179 | 180 | if self.latency is None: 181 | self.latency = self.step 182 | 183 | assert self.step <= self.latency, "Invalid latency requested" 184 | 185 | self.num_overlapping_windows = int(round(self.latency / self.step)) 186 | self.aggregate = AggregationStrategy.build(self.strategy, self.cropping_mode) 187 | 188 | def _prepend( 189 | self, 190 | output_window: SlidingWindowFeature, 191 | output_region: Segment, 192 | buffers: List[SlidingWindowFeature], 193 | ): 194 | # FIXME instead of prepending the output of the first chunk, 195 | # add padding of `chunk_duration - latency` seconds at the 196 | # beginning of the stream so scores can be aggregated accordingly. 197 | # Remember to shift predictions by the padding. 
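        # Until that padding exists: if this is the very first buffer of the
        # stream, crop it from t=0 up to the end of the output region and
        # overwrite its tail with the aggregated scores, so that the first
        # prediction covers the beginning of the stream rather than only the
        # last `step` seconds.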
198 | last_buffer = buffers[-1].extent 199 | # Prepend prediction until we match the latency in case of first buffer 200 | if len(buffers) == 1 and last_buffer.start == 0: 201 | num_frames = output_window.data.shape[0] 202 | first_region = Segment(0, output_region.end) 203 | first_output = buffers[0].crop( 204 | first_region, mode=self.cropping_mode, fixed=first_region.duration 205 | ) 206 | first_output[-num_frames:] = output_window.data 207 | resolution = output_region.end / first_output.shape[0] 208 | output_window = SlidingWindowFeature( 209 | first_output, 210 | SlidingWindow(start=0, duration=resolution, step=resolution), 211 | ) 212 | return output_window 213 | 214 | def __call__(self, buffers: List[SlidingWindowFeature]) -> SlidingWindowFeature: 215 | # Determine overlapping region to aggregate 216 | start = buffers[-1].extent.end - self.latency 217 | region = Segment(start, start + self.step) 218 | return self._prepend(self.aggregate(buffers, region), region, buffers) 219 | -------------------------------------------------------------------------------- /src/diart/blocks/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import dataclass 3 | from typing import Any, Tuple, Sequence, Text 4 | 5 | from pyannote.core import SlidingWindowFeature 6 | from pyannote.metrics.base import BaseMetric 7 | 8 | from .. import utils 9 | from ..audio import FilePath, AudioLoader 10 | 11 | 12 | @dataclass 13 | class HyperParameter: 14 | """Represents a pipeline hyper-parameter that can be tuned by diart""" 15 | 16 | name: Text 17 | """Name of the hyper-parameter (e.g. tau_active)""" 18 | low: float 19 | """Lowest value that this parameter can take""" 20 | high: float 21 | """Highest value that this parameter can take""" 22 | 23 | @staticmethod 24 | def from_name(name: Text) -> "HyperParameter": 25 | """Create a HyperParameter object given its name. 26 | 27 | Parameters 28 | ---------- 29 | name: str 30 | Name of the hyper-parameter 31 | 32 | Returns 33 | ------- 34 | HyperParameter 35 | """ 36 | if name == "tau_active": 37 | return TauActive 38 | if name == "rho_update": 39 | return RhoUpdate 40 | if name == "delta_new": 41 | return DeltaNew 42 | raise ValueError(f"Hyper-parameter '{name}' not recognized") 43 | 44 | 45 | TauActive = HyperParameter("tau_active", low=0, high=1) 46 | RhoUpdate = HyperParameter("rho_update", low=0, high=1) 47 | DeltaNew = HyperParameter("delta_new", low=0, high=2) 48 | 49 | 50 | class PipelineConfig(ABC): 51 | """Configuration containing the required 52 | parameters to build and run a pipeline""" 53 | 54 | @property 55 | @abstractmethod 56 | def duration(self) -> float: 57 | """The duration of an input audio chunk (in seconds)""" 58 | pass 59 | 60 | @property 61 | @abstractmethod 62 | def step(self) -> float: 63 | """The step between two consecutive input audio chunks (in seconds)""" 64 | pass 65 | 66 | @property 67 | @abstractmethod 68 | def latency(self) -> float: 69 | """The algorithmic latency of the pipeline (in seconds). 70 | At time `t` of the audio stream, the pipeline will 71 | output predictions for time `t - latency`. 
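        For example, with `latency=2` the scores emitted when the stream
        reaches t=10s describe the audio around t=8s.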
72 | """ 73 | pass 74 | 75 | @property 76 | @abstractmethod 77 | def sample_rate(self) -> int: 78 | """The sample rate of the input audio stream""" 79 | pass 80 | 81 | def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]: 82 | file_duration = AudioLoader(self.sample_rate, mono=True).get_duration(filepath) 83 | right = utils.get_padding_right(self.latency, self.step) 84 | left = utils.get_padding_left(file_duration + right, self.duration) 85 | return left, right 86 | 87 | 88 | class Pipeline(ABC): 89 | """Represents a streaming audio pipeline""" 90 | 91 | @staticmethod 92 | @abstractmethod 93 | def get_config_class() -> type: 94 | pass 95 | 96 | @staticmethod 97 | @abstractmethod 98 | def suggest_metric() -> BaseMetric: 99 | pass 100 | 101 | @staticmethod 102 | @abstractmethod 103 | def hyper_parameters() -> Sequence[HyperParameter]: 104 | pass 105 | 106 | @property 107 | @abstractmethod 108 | def config(self) -> PipelineConfig: 109 | pass 110 | 111 | @abstractmethod 112 | def reset(self): 113 | pass 114 | 115 | @abstractmethod 116 | def set_timestamp_shift(self, shift: float): 117 | pass 118 | 119 | @abstractmethod 120 | def __call__( 121 | self, waveforms: Sequence[SlidingWindowFeature] 122 | ) -> Sequence[Tuple[Any, SlidingWindowFeature]]: 123 | """Runs the next steps of the pipeline 124 | given a list of consecutive audio chunks. 125 | 126 | Parameters 127 | ---------- 128 | waveforms: Sequence[SlidingWindowFeature] 129 | Consecutive chunk waveforms for the pipeline to ingest 130 | 131 | Returns 132 | ------- 133 | Sequence[Tuple[Any, SlidingWindowFeature]] 134 | For each input waveform, a tuple containing 135 | the pipeline output and its respective audio 136 | """ 137 | pass 138 | -------------------------------------------------------------------------------- /src/diart/blocks/clustering.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Iterable, Tuple 2 | 3 | import numpy as np 4 | import torch 5 | from pyannote.core import SlidingWindowFeature 6 | 7 | from ..mapping import SpeakerMap, SpeakerMapBuilder 8 | 9 | 10 | class OnlineSpeakerClustering: 11 | """Implements constrained incremental online clustering of speakers and manages cluster centers. 12 | 13 | Parameters 14 | ---------- 15 | tau_active:float 16 | Threshold for detecting active speakers. This threshold is applied on the maximum value of per-speaker output 17 | activation of the local segmentation model. 18 | rho_update: float 19 | Threshold for considering the extracted embedding when updating the centroid of the local speaker. 20 | The centroid to which a local speaker is mapped is only updated if the ratio of speech/chunk duration 21 | of a given local speaker is greater than this threshold. 22 | delta_new: float 23 | Threshold on the distance between a speaker embedding and a centroid. If the distance between a local speaker and all 24 | centroids is larger than delta_new, then a new centroid is created for the current speaker. 25 | metric: str. Defaults to "cosine". 26 | The distance metric to use. 27 | max_speakers: int 28 | Maximum number of global speakers to track through a conversation. Defaults to 20. 
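    Example
    --------
    A minimal usage sketch for a single chunk; the number of frames, local
    speakers and the embedding dimension below are illustrative:

    >>> clustering = OnlineSpeakerClustering(tau_active=0.6, rho_update=0.3, delta_new=1)
    >>> resolution = SlidingWindow(start=0, duration=0.01, step=0.01)
    >>> segmentation = SlidingWindowFeature(np.random.rand(500, 3), resolution)
    >>> embeddings = torch.rand(3, 512)
    >>> permuted = clustering(segmentation, embeddings)  # scores re-aligned to global speakers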
29 | """ 30 | 31 | def __init__( 32 | self, 33 | tau_active: float, 34 | rho_update: float, 35 | delta_new: float, 36 | metric: Optional[str] = "cosine", 37 | max_speakers: int = 20, 38 | ): 39 | self.tau_active = tau_active 40 | self.rho_update = rho_update 41 | self.delta_new = delta_new 42 | self.metric = metric 43 | self.max_speakers = max_speakers 44 | self.centers: Optional[np.ndarray] = None 45 | self.active_centers = set() 46 | self.blocked_centers = set() 47 | 48 | @property 49 | def num_free_centers(self) -> int: 50 | return self.max_speakers - self.num_known_speakers - self.num_blocked_speakers 51 | 52 | @property 53 | def num_known_speakers(self) -> int: 54 | return len(self.active_centers) 55 | 56 | @property 57 | def num_blocked_speakers(self) -> int: 58 | return len(self.blocked_centers) 59 | 60 | @property 61 | def inactive_centers(self) -> List[int]: 62 | return [ 63 | c 64 | for c in range(self.max_speakers) 65 | if c not in self.active_centers or c in self.blocked_centers 66 | ] 67 | 68 | def get_next_center_position(self) -> Optional[int]: 69 | for center in range(self.max_speakers): 70 | if center not in self.active_centers and center not in self.blocked_centers: 71 | return center 72 | 73 | def init_centers(self, dimension: int): 74 | """Initializes the speaker centroid matrix 75 | 76 | Parameters 77 | ---------- 78 | dimension: int 79 | Dimension of embeddings used for representing a speaker. 80 | """ 81 | self.centers = np.zeros((self.max_speakers, dimension)) 82 | self.active_centers = set() 83 | self.blocked_centers = set() 84 | 85 | def update(self, assignments: Iterable[Tuple[int, int]], embeddings: np.ndarray): 86 | """Updates the speaker centroids given a list of assignments and local speaker embeddings 87 | 88 | Parameters 89 | ---------- 90 | assignments: Iterable[Tuple[int, int]]) 91 | An iterable of tuples with two elements having the first element as the source speaker 92 | and the second element as the target speaker. 93 | embeddings: np.ndarray, shape (local_speakers, embedding_dim) 94 | Matrix containing embeddings for all local speakers. 95 | """ 96 | if self.centers is not None: 97 | for l_spk, g_spk in assignments: 98 | assert g_spk in self.active_centers, "Cannot update unknown centers" 99 | self.centers[g_spk] += embeddings[l_spk] 100 | 101 | def add_center(self, embedding: np.ndarray) -> int: 102 | """Add a new speaker centroid initialized to a given embedding 103 | 104 | Parameters 105 | ---------- 106 | embedding: np.ndarray 107 | Embedding vector of some local speaker 108 | 109 | Returns 110 | ------- 111 | center_index: int 112 | Index of the created center 113 | """ 114 | center = self.get_next_center_position() 115 | self.centers[center] = embedding 116 | self.active_centers.add(center) 117 | return center 118 | 119 | def identify( 120 | self, segmentation: SlidingWindowFeature, embeddings: torch.Tensor 121 | ) -> SpeakerMap: 122 | """Identify the centroids to which the input speaker embeddings belong. 123 | 124 | Parameters 125 | ---------- 126 | segmentation: np.ndarray, shape (frames, local_speakers) 127 | Matrix of segmentation outputs 128 | embeddings: np.ndarray, shape (local_speakers, embedding_dim) 129 | Matrix of embeddings 130 | 131 | Returns 132 | ------- 133 | speaker_map: SpeakerMap 134 | A mapping from local speakers to global speakers. 
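        Notes
        -----
        Local speakers whose maximum activation is below `tau_active`, or whose
        embedding contains NaN values, are ignored. The remaining speakers are
        mapped to known centroids based on embedding distance, keeping only
        assignments below `delta_new`. An unassigned speaker gets a new centroid
        if a free one is available and its speech ratio exceeds `rho_update`;
        otherwise it falls back to the closest unassigned known centroid.
        Centroids are only updated with embeddings of speakers matched under the
        distance threshold whose speech ratio exceeds `rho_update`.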
135 | """ 136 | embeddings = embeddings.detach().cpu().numpy() 137 | active_speakers = np.where( 138 | np.max(segmentation.data, axis=0) >= self.tau_active 139 | )[0] 140 | long_speakers = np.where(np.mean(segmentation.data, axis=0) >= self.rho_update)[ 141 | 0 142 | ] 143 | # Remove speakers that have NaN embeddings 144 | no_nan_embeddings = np.where(~np.isnan(embeddings).any(axis=1))[0] 145 | active_speakers = np.intersect1d(active_speakers, no_nan_embeddings) 146 | 147 | num_local_speakers = segmentation.data.shape[1] 148 | 149 | if self.centers is None: 150 | self.init_centers(embeddings.shape[1]) 151 | assignments = [ 152 | (spk, self.add_center(embeddings[spk])) for spk in active_speakers 153 | ] 154 | return SpeakerMapBuilder.hard_map( 155 | shape=(num_local_speakers, self.max_speakers), 156 | assignments=assignments, 157 | maximize=False, 158 | ) 159 | 160 | # Obtain a mapping based on distances between embeddings and centers 161 | dist_map = SpeakerMapBuilder.dist(embeddings, self.centers, self.metric) 162 | # Remove any assignments containing invalid speakers 163 | inactive_speakers = np.array( 164 | [spk for spk in range(num_local_speakers) if spk not in active_speakers] 165 | ) 166 | dist_map = dist_map.unmap_speakers(inactive_speakers, self.inactive_centers) 167 | # Keep assignments under the distance threshold 168 | valid_map = dist_map.unmap_threshold(self.delta_new) 169 | 170 | # Some speakers might be unidentified 171 | missed_speakers = [ 172 | s for s in active_speakers if not valid_map.is_source_speaker_mapped(s) 173 | ] 174 | 175 | # Add assignments to new centers if possible 176 | new_center_speakers = [] 177 | for spk in missed_speakers: 178 | has_space = len(new_center_speakers) < self.num_free_centers 179 | if has_space and spk in long_speakers: 180 | # Flag as a new center 181 | new_center_speakers.append(spk) 182 | else: 183 | # Cannot create a new center 184 | # Get global speakers in order of preference 185 | preferences = np.argsort(dist_map.mapping_matrix[spk, :]) 186 | preferences = [ 187 | g_spk for g_spk in preferences if g_spk in self.active_centers 188 | ] 189 | # Get the free global speakers among the preferences 190 | _, g_assigned = valid_map.valid_assignments() 191 | free = [g_spk for g_spk in preferences if g_spk not in g_assigned] 192 | if free: 193 | # The best global speaker is the closest free one 194 | valid_map = valid_map.set_source_speaker(spk, free[0]) 195 | 196 | # Update known centers 197 | to_update = [ 198 | (ls, gs) 199 | for ls, gs in zip(*valid_map.valid_assignments()) 200 | if ls not in missed_speakers and ls in long_speakers 201 | ] 202 | self.update(to_update, embeddings) 203 | 204 | # Add new centers 205 | for spk in new_center_speakers: 206 | valid_map = valid_map.set_source_speaker( 207 | spk, self.add_center(embeddings[spk]) 208 | ) 209 | 210 | return valid_map 211 | 212 | def __call__( 213 | self, segmentation: SlidingWindowFeature, embeddings: torch.Tensor 214 | ) -> SlidingWindowFeature: 215 | return SlidingWindowFeature( 216 | self.identify(segmentation, embeddings).apply(segmentation.data), 217 | segmentation.sliding_window, 218 | ) 219 | -------------------------------------------------------------------------------- /src/diart/blocks/diarization.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Sequence 4 | 5 | import numpy as np 6 | import torch 7 | from pyannote.core import Annotation, SlidingWindowFeature, 
SlidingWindow, Segment 8 | from pyannote.metrics.base import BaseMetric 9 | from pyannote.metrics.diarization import DiarizationErrorRate 10 | from typing_extensions import Literal 11 | 12 | from . import base 13 | from .aggregation import DelayedAggregation 14 | from .clustering import OnlineSpeakerClustering 15 | from .embedding import OverlapAwareSpeakerEmbedding 16 | from .segmentation import SpeakerSegmentation 17 | from .utils import Binarize 18 | from .. import models as m 19 | 20 | 21 | class SpeakerDiarizationConfig(base.PipelineConfig): 22 | def __init__( 23 | self, 24 | segmentation: m.SegmentationModel | None = None, 25 | embedding: m.EmbeddingModel | None = None, 26 | duration: float = 5, 27 | step: float = 0.5, 28 | latency: float | Literal["max", "min"] | None = None, 29 | tau_active: float = 0.6, 30 | rho_update: float = 0.3, 31 | delta_new: float = 1, 32 | gamma: float = 3, 33 | beta: float = 10, 34 | max_speakers: int = 20, 35 | normalize_embedding_weights: bool = False, 36 | device: torch.device | None = None, 37 | sample_rate: int = 16000, 38 | **kwargs, 39 | ): 40 | # Default segmentation model is pyannote/segmentation 41 | self.segmentation = segmentation or m.SegmentationModel.from_pyannote( 42 | "pyannote/segmentation" 43 | ) 44 | 45 | # Default embedding model is pyannote/embedding 46 | self.embedding = embedding or m.EmbeddingModel.from_pyannote( 47 | "pyannote/embedding" 48 | ) 49 | 50 | self._duration = duration 51 | self._sample_rate = sample_rate 52 | 53 | # Latency defaults to the step duration 54 | self._step = step 55 | self._latency = latency 56 | if self._latency is None or self._latency == "min": 57 | self._latency = self._step 58 | elif self._latency == "max": 59 | self._latency = self._duration 60 | 61 | self.tau_active = tau_active 62 | self.rho_update = rho_update 63 | self.delta_new = delta_new 64 | self.gamma = gamma 65 | self.beta = beta 66 | self.max_speakers = max_speakers 67 | self.normalize_embedding_weights = normalize_embedding_weights 68 | self.device = device or torch.device( 69 | "cuda" if torch.cuda.is_available() else "cpu" 70 | ) 71 | 72 | @property 73 | def duration(self) -> float: 74 | return self._duration 75 | 76 | @property 77 | def step(self) -> float: 78 | return self._step 79 | 80 | @property 81 | def latency(self) -> float: 82 | return self._latency 83 | 84 | @property 85 | def sample_rate(self) -> int: 86 | return self._sample_rate 87 | 88 | 89 | class SpeakerDiarization(base.Pipeline): 90 | def __init__(self, config: SpeakerDiarizationConfig | None = None): 91 | self._config = SpeakerDiarizationConfig() if config is None else config 92 | 93 | msg = f"Latency should be in the range [{self._config.step}, {self._config.duration}]" 94 | assert self._config.step <= self._config.latency <= self._config.duration, msg 95 | 96 | self.segmentation = SpeakerSegmentation( 97 | self._config.segmentation, self._config.device 98 | ) 99 | self.embedding = OverlapAwareSpeakerEmbedding( 100 | self._config.embedding, 101 | self._config.gamma, 102 | self._config.beta, 103 | norm=1, 104 | normalize_weights=self._config.normalize_embedding_weights, 105 | device=self._config.device, 106 | ) 107 | self.pred_aggregation = DelayedAggregation( 108 | self._config.step, 109 | self._config.latency, 110 | strategy="hamming", 111 | cropping_mode="loose", 112 | ) 113 | self.audio_aggregation = DelayedAggregation( 114 | self._config.step, 115 | self._config.latency, 116 | strategy="first", 117 | cropping_mode="center", 118 | ) 119 | self.binarize = 
Binarize(self._config.tau_active) 120 | 121 | # Internal state, handle with care 122 | self.timestamp_shift = 0 123 | self.clustering = None 124 | self.chunk_buffer, self.pred_buffer = [], [] 125 | self.reset() 126 | 127 | @staticmethod 128 | def get_config_class() -> type: 129 | return SpeakerDiarizationConfig 130 | 131 | @staticmethod 132 | def suggest_metric() -> BaseMetric: 133 | return DiarizationErrorRate(collar=0, skip_overlap=False) 134 | 135 | @staticmethod 136 | def hyper_parameters() -> Sequence[base.HyperParameter]: 137 | return [base.TauActive, base.RhoUpdate, base.DeltaNew] 138 | 139 | @property 140 | def config(self) -> SpeakerDiarizationConfig: 141 | return self._config 142 | 143 | def set_timestamp_shift(self, shift: float): 144 | self.timestamp_shift = shift 145 | 146 | def reset(self): 147 | self.set_timestamp_shift(0) 148 | self.clustering = OnlineSpeakerClustering( 149 | self.config.tau_active, 150 | self.config.rho_update, 151 | self.config.delta_new, 152 | "cosine", 153 | self.config.max_speakers, 154 | ) 155 | self.chunk_buffer, self.pred_buffer = [], [] 156 | 157 | def __call__( 158 | self, waveforms: Sequence[SlidingWindowFeature] 159 | ) -> Sequence[tuple[Annotation, SlidingWindowFeature]]: 160 | """Diarize the next audio chunks of an audio stream. 161 | 162 | Parameters 163 | ---------- 164 | waveforms: Sequence[SlidingWindowFeature] 165 | A sequence of consecutive audio chunks from an audio stream. 166 | 167 | Returns 168 | ------- 169 | Sequence[tuple[Annotation, SlidingWindowFeature]] 170 | Speaker diarization of each chunk alongside their corresponding audio. 171 | """ 172 | batch_size = len(waveforms) 173 | msg = "Pipeline expected at least 1 input" 174 | assert batch_size >= 1, msg 175 | 176 | # Create batch from chunk sequence, shape (batch, samples, channels) 177 | batch = torch.stack([torch.from_numpy(w.data) for w in waveforms]) 178 | 179 | expected_num_samples = int( 180 | np.rint(self.config.duration * self.config.sample_rate) 181 | ) 182 | msg = f"Expected {expected_num_samples} samples per chunk, but got {batch.shape[1]}" 183 | assert batch.shape[1] == expected_num_samples, msg 184 | 185 | # Extract segmentation and embeddings 186 | segmentations = self.segmentation(batch) # shape (batch, frames, speakers) 187 | # embeddings has shape (batch, speakers, emb_dim) 188 | embeddings = self.embedding(batch, segmentations) 189 | 190 | seg_resolution = waveforms[0].extent.duration / segmentations.shape[1] 191 | 192 | outputs = [] 193 | for wav, seg, emb in zip(waveforms, segmentations, embeddings): 194 | # Add timestamps to segmentation 195 | sw = SlidingWindow( 196 | start=wav.extent.start, 197 | duration=seg_resolution, 198 | step=seg_resolution, 199 | ) 200 | seg = SlidingWindowFeature(seg.cpu().numpy(), sw) 201 | 202 | # Update clustering state and permute segmentation 203 | permuted_seg = self.clustering(seg, emb) 204 | 205 | # Update sliding buffer 206 | self.chunk_buffer.append(wav) 207 | self.pred_buffer.append(permuted_seg) 208 | 209 | # Aggregate buffer outputs for this time step 210 | agg_waveform = self.audio_aggregation(self.chunk_buffer) 211 | agg_prediction = self.pred_aggregation(self.pred_buffer) 212 | agg_prediction = self.binarize(agg_prediction) 213 | 214 | # Shift prediction timestamps if required 215 | if self.timestamp_shift != 0: 216 | shifted_agg_prediction = Annotation(agg_prediction.uri) 217 | for segment, track, speaker in agg_prediction.itertracks( 218 | yield_label=True 219 | ): 220 | new_segment = Segment( 221 | segment.start 
+ self.timestamp_shift, 222 | segment.end + self.timestamp_shift, 223 | ) 224 | shifted_agg_prediction[new_segment, track] = speaker 225 | agg_prediction = shifted_agg_prediction 226 | 227 | outputs.append((agg_prediction, agg_waveform)) 228 | 229 | # Make place for new chunks in buffer if required 230 | if len(self.chunk_buffer) == self.pred_aggregation.num_overlapping_windows: 231 | self.chunk_buffer = self.chunk_buffer[1:] 232 | self.pred_buffer = self.pred_buffer[1:] 233 | 234 | return outputs 235 | -------------------------------------------------------------------------------- /src/diart/blocks/embedding.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union, Text 2 | 3 | import torch 4 | from einops import rearrange 5 | 6 | from .. import functional as F 7 | from ..features import TemporalFeatures, TemporalFeatureFormatter 8 | from ..models import EmbeddingModel 9 | 10 | 11 | class SpeakerEmbedding: 12 | def __init__(self, model: EmbeddingModel, device: Optional[torch.device] = None): 13 | self.model = model 14 | self.model.eval() 15 | self.device = device 16 | if self.device is None: 17 | self.device = torch.device("cpu") 18 | self.model.to(self.device) 19 | self.waveform_formatter = TemporalFeatureFormatter() 20 | self.weights_formatter = TemporalFeatureFormatter() 21 | 22 | @staticmethod 23 | def from_pretrained( 24 | model, 25 | use_hf_token: Union[Text, bool, None] = True, 26 | device: Optional[torch.device] = None, 27 | ) -> "SpeakerEmbedding": 28 | emb_model = EmbeddingModel.from_pretrained(model, use_hf_token) 29 | return SpeakerEmbedding(emb_model, device) 30 | 31 | def __call__( 32 | self, waveform: TemporalFeatures, weights: Optional[TemporalFeatures] = None 33 | ) -> torch.Tensor: 34 | """ 35 | Calculate speaker embeddings of input audio. 36 | If weights are given, calculate many speaker embeddings from the same waveform. 37 | 38 | Parameters 39 | ---------- 40 | waveform: TemporalFeatures, shape (samples, channels) or (batch, samples, channels) 41 | weights: Optional[TemporalFeatures], shape (frames, speakers) or (batch, frames, speakers) 42 | Per-speaker and per-frame weights. Defaults to no weights. 43 | 44 | Returns 45 | ------- 46 | embeddings: torch.Tensor 47 | If weights are provided, the shape is (batch, speakers, embedding_dim), 48 | otherwise the shape is (batch, embedding_dim). 49 | If batch size == 1, the batch dimension is omitted. 50 | """ 51 | with torch.no_grad(): 52 | inputs = self.waveform_formatter.cast(waveform).to(self.device) 53 | inputs = rearrange(inputs, "batch sample channel -> batch channel sample") 54 | if weights is not None: 55 | weights = self.weights_formatter.cast(weights).to(self.device) 56 | batch_size, _, num_speakers = weights.shape 57 | inputs = inputs.repeat(1, num_speakers, 1) 58 | weights = rearrange(weights, "batch frame spk -> (batch spk) frame") 59 | inputs = rearrange(inputs, "batch spk sample -> (batch spk) 1 sample") 60 | output = rearrange( 61 | self.model(inputs, weights), 62 | "(batch spk) feat -> batch spk feat", 63 | batch=batch_size, 64 | spk=num_speakers, 65 | ) 66 | else: 67 | output = self.model(inputs) 68 | return output.squeeze().cpu() 69 | 70 | 71 | class OverlappedSpeechPenalty: 72 | """Applies a penalty on overlapping speech and low-confidence regions to speaker segmentation scores. 73 | 74 | .. 
note:: 75 | For more information, see `"Overlap-Aware Low-Latency Online Speaker Diarization 76 | based on End-to-End Local Segmentation" `_ 77 | (Section 2.2.1 Segmentation-driven speaker embedding). This block implements Equation 2. 78 | 79 | Parameters 80 | ---------- 81 | gamma: float, optional 82 | Exponent to lower low-confidence predictions. 83 | Defaults to 3. 84 | beta: float, optional 85 | Temperature parameter (actually 1/beta) to lower joint speaker activations. 86 | Defaults to 10. 87 | normalize: bool, optional 88 | Whether to min-max normalize weights to be in the range [0, 1]. 89 | Defaults to False. 90 | """ 91 | 92 | def __init__(self, gamma: float = 3, beta: float = 10, normalize: bool = False): 93 | self.gamma = gamma 94 | self.beta = beta 95 | self.formatter = TemporalFeatureFormatter() 96 | self.normalize = normalize 97 | 98 | def __call__(self, segmentation: TemporalFeatures) -> TemporalFeatures: 99 | weights = self.formatter.cast(segmentation) # shape (batch, frames, speakers) 100 | with torch.inference_mode(): 101 | weights = F.overlapped_speech_penalty(weights, self.gamma, self.beta) 102 | if self.normalize: 103 | min_values = weights.min(dim=1, keepdim=True).values 104 | max_values = weights.max(dim=1, keepdim=True).values 105 | weights = (weights - min_values) / (max_values - min_values) 106 | weights.nan_to_num_(1e-8) 107 | return self.formatter.restore_type(weights) 108 | 109 | 110 | class EmbeddingNormalization: 111 | def __init__(self, norm: Union[float, torch.Tensor] = 1): 112 | self.norm = norm 113 | # Add batch dimension if missing 114 | if isinstance(self.norm, torch.Tensor) and self.norm.ndim == 2: 115 | self.norm = self.norm.unsqueeze(0) 116 | 117 | def __call__(self, embeddings: torch.Tensor) -> torch.Tensor: 118 | with torch.inference_mode(): 119 | norm_embs = F.normalize_embeddings(embeddings, self.norm) 120 | return norm_embs 121 | 122 | 123 | class OverlapAwareSpeakerEmbedding: 124 | """ 125 | Extract overlap-aware speaker embeddings given an audio chunk and its segmentation. 126 | 127 | Parameters 128 | ---------- 129 | model: EmbeddingModel 130 | A pre-trained embedding model. 131 | gamma: float, optional 132 | Exponent to lower low-confidence predictions. 133 | Defaults to 3. 134 | beta: float, optional 135 | Softmax's temperature parameter (actually 1/beta) to lower joint speaker activations. 136 | Defaults to 10. 137 | norm: float or torch.Tensor of shape (batch, speakers, 1) where batch is optional 138 | The target norm for the embeddings. It can be different for each speaker. 139 | Defaults to 1. 140 | normalize_weights: bool, optional 141 | Whether to min-max normalize embedding weights to be in the range [0, 1]. 142 | device: Optional[torch.device] 143 | The device on which to run the embedding model. 144 | Defaults to GPU if available or CPU if not. 
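    Example
    --------
    A minimal sketch, assuming a 5-second mono chunk at 16kHz; the frame count
    and the pre-trained model name below are illustrative:

    >>> embedding = OverlapAwareSpeakerEmbedding.from_pretrained("pyannote/embedding")
    >>> waveform = torch.rand(80000, 1)    # shape (samples, channels)
    >>> segmentation = torch.rand(293, 3)  # shape (frames, local_speakers)
    >>> embeddings = embedding(waveform, segmentation)  # one embedding per local speaker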
145 | """ 146 | 147 | def __init__( 148 | self, 149 | model: EmbeddingModel, 150 | gamma: float = 3, 151 | beta: float = 10, 152 | norm: Union[float, torch.Tensor] = 1, 153 | normalize_weights: bool = False, 154 | device: Optional[torch.device] = None, 155 | ): 156 | self.embedding = SpeakerEmbedding(model, device) 157 | self.osp = OverlappedSpeechPenalty(gamma, beta, normalize_weights) 158 | self.normalize = EmbeddingNormalization(norm) 159 | 160 | @staticmethod 161 | def from_pretrained( 162 | model, 163 | gamma: float = 3, 164 | beta: float = 10, 165 | norm: Union[float, torch.Tensor] = 1, 166 | use_hf_token: Union[Text, bool, None] = True, 167 | normalize_weights: bool = False, 168 | device: Optional[torch.device] = None, 169 | ): 170 | model = EmbeddingModel.from_pretrained(model, use_hf_token) 171 | return OverlapAwareSpeakerEmbedding( 172 | model, gamma, beta, norm, normalize_weights, device 173 | ) 174 | 175 | def __call__( 176 | self, waveform: TemporalFeatures, segmentation: TemporalFeatures 177 | ) -> torch.Tensor: 178 | return self.normalize(self.embedding(waveform, self.osp(segmentation))) 179 | -------------------------------------------------------------------------------- /src/diart/blocks/segmentation.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union, Text 2 | 3 | import torch 4 | from einops import rearrange 5 | 6 | from ..features import TemporalFeatures, TemporalFeatureFormatter 7 | from ..models import SegmentationModel 8 | 9 | 10 | class SpeakerSegmentation: 11 | def __init__(self, model: SegmentationModel, device: Optional[torch.device] = None): 12 | self.model = model 13 | self.model.eval() 14 | self.device = device 15 | if self.device is None: 16 | self.device = torch.device("cpu") 17 | self.model.to(self.device) 18 | self.formatter = TemporalFeatureFormatter() 19 | 20 | @staticmethod 21 | def from_pretrained( 22 | model, 23 | use_hf_token: Union[Text, bool, None] = True, 24 | device: Optional[torch.device] = None, 25 | ) -> "SpeakerSegmentation": 26 | seg_model = SegmentationModel.from_pretrained(model, use_hf_token) 27 | return SpeakerSegmentation(seg_model, device) 28 | 29 | def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures: 30 | """ 31 | Calculate the speaker segmentation of input audio. 32 | 33 | Parameters 34 | ---------- 35 | waveform: TemporalFeatures, shape (samples, channels) or (batch, samples, channels) 36 | 37 | Returns 38 | ------- 39 | speaker_segmentation: TemporalFeatures, shape (batch, frames, speakers) 40 | The batch dimension is omitted if waveform is a `SlidingWindowFeature`. 41 | """ 42 | with torch.no_grad(): 43 | wave = rearrange( 44 | self.formatter.cast(waveform), 45 | "batch sample channel -> batch channel sample", 46 | ) 47 | output = self.model(wave.to(self.device)).cpu() 48 | return self.formatter.restore_type(output) 49 | -------------------------------------------------------------------------------- /src/diart/blocks/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Text, Optional 2 | 3 | import numpy as np 4 | import torch 5 | from pyannote.core import Annotation, Segment, SlidingWindowFeature 6 | import torchaudio.transforms as T 7 | 8 | from ..features import TemporalFeatures, TemporalFeatureFormatter 9 | 10 | 11 | class Binarize: 12 | """ 13 | Transform a speaker segmentation from the discrete-time domain 14 | into a continuous-time speaker segmentation. 
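    Frames whose score is above `threshold` are considered active; consecutive
    active frames are merged into speaker turns and returned as a pyannote
    `Annotation`.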
15 | 16 | Parameters 17 | ---------- 18 | threshold: float 19 | Probability threshold to determine if a speaker is active at a given frame. 20 | uri: Optional[Text] 21 | Uri of the audio stream. Defaults to no uri. 22 | """ 23 | 24 | def __init__(self, threshold: float, uri: Optional[Text] = None): 25 | self.uri = uri 26 | self.threshold = threshold 27 | 28 | def __call__(self, segmentation: SlidingWindowFeature) -> Annotation: 29 | """ 30 | Return the continuous-time segmentation 31 | corresponding to the discrete-time input segmentation. 32 | 33 | Parameters 34 | ---------- 35 | segmentation: SlidingWindowFeature 36 | Discrete-time speaker segmentation. 37 | 38 | Returns 39 | ------- 40 | annotation: Annotation 41 | Continuous-time speaker segmentation. 42 | """ 43 | num_frames, num_speakers = segmentation.data.shape 44 | timestamps = segmentation.sliding_window 45 | is_active = segmentation.data > self.threshold 46 | # Artificially add last inactive frame to close any remaining speaker turns 47 | is_active = np.append(is_active, [[False] * num_speakers], axis=0) 48 | start_times = np.zeros(num_speakers) + timestamps[0].middle 49 | annotation = Annotation(uri=self.uri, modality="speech") 50 | for t in range(num_frames): 51 | # Any (False, True) starts a speaker turn at "True" index 52 | onsets = np.logical_and(np.logical_not(is_active[t]), is_active[t + 1]) 53 | start_times[onsets] = timestamps[t + 1].middle 54 | # Any (True, False) ends a speaker turn at "False" index 55 | offsets = np.logical_and(is_active[t], np.logical_not(is_active[t + 1])) 56 | for spk in np.where(offsets)[0]: 57 | region = Segment(start_times[spk], timestamps[t + 1].middle) 58 | annotation[region, spk] = f"speaker{spk}" 59 | return annotation 60 | 61 | 62 | class Resample: 63 | """Dynamically resample audio chunks. 64 | 65 | Parameters 66 | ---------- 67 | sample_rate: int 68 | Original sample rate of the input audio 69 | resample_rate: int 70 | Sample rate of the output 71 | """ 72 | 73 | def __init__( 74 | self, 75 | sample_rate: int, 76 | resample_rate: int, 77 | device: Optional[torch.device] = None, 78 | ): 79 | self.device = device 80 | if self.device is None: 81 | self.device = torch.device("cpu") 82 | self.resample = T.Resample(sample_rate, resample_rate).to(self.device) 83 | self.formatter = TemporalFeatureFormatter() 84 | 85 | def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures: 86 | wav = self.formatter.cast(waveform).to(self.device) # shape (batch, samples, 1) 87 | with torch.no_grad(): 88 | resampled_wav = self.resample(wav.transpose(-1, -2)).transpose(-1, -2) 89 | return self.formatter.restore_type(resampled_wav) 90 | 91 | 92 | class AdjustVolume: 93 | """Change the volume of an audio chunk. 94 | 95 | Notice that the output volume might be different to avoid saturation. 96 | 97 | Parameters 98 | ---------- 99 | volume_in_db: float 100 | Target volume in dB. 101 | """ 102 | 103 | def __init__(self, volume_in_db: float): 104 | self.target_db = volume_in_db 105 | self.formatter = TemporalFeatureFormatter() 106 | 107 | @staticmethod 108 | def get_volumes(waveforms: torch.Tensor) -> torch.Tensor: 109 | """Compute the volumes of a set of audio chunks. 110 | 111 | Parameters 112 | ---------- 113 | waveforms: torch.Tensor 114 | Audio chunks. Shape (batch, samples, channels). 115 | 116 | Returns 117 | ------- 118 | volumes: torch.Tensor 119 | Audio chunk volumes per channel. 
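            Values are in dB, computed as `10 * log10` of the mean squared amplitude.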
Shape (batch, 1, channels) 120 | """ 121 | return 10 * torch.log10( 122 | torch.mean(torch.abs(waveforms) ** 2, dim=1, keepdim=True) 123 | ) 124 | 125 | def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures: 126 | wav = self.formatter.cast(waveform) # shape (batch, samples, channels) 127 | with torch.no_grad(): 128 | # Compute current volume per chunk, shape (batch, 1, channels) 129 | current_volumes = self.get_volumes(wav) 130 | # Determine gain to reach the target volume 131 | gains = 10 ** ((self.target_db - current_volumes) / 20) 132 | # Apply gain 133 | wav = gains * wav 134 | # If maximum value is greater than one, normalize chunk 135 | maximums = torch.clamp(torch.amax(torch.abs(wav), dim=1, keepdim=True), 1) 136 | wav = wav / maximums 137 | return self.formatter.restore_type(wav) 138 | -------------------------------------------------------------------------------- /src/diart/blocks/vad.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Sequence 4 | 5 | import numpy as np 6 | import torch 7 | from pyannote.core import ( 8 | Annotation, 9 | Timeline, 10 | SlidingWindowFeature, 11 | SlidingWindow, 12 | Segment, 13 | ) 14 | from pyannote.metrics.base import BaseMetric 15 | from pyannote.metrics.detection import DetectionErrorRate 16 | from typing_extensions import Literal 17 | 18 | from . import base 19 | from .aggregation import DelayedAggregation 20 | from .segmentation import SpeakerSegmentation 21 | from .utils import Binarize 22 | from .. import models as m 23 | from .. import utils 24 | 25 | 26 | class VoiceActivityDetectionConfig(base.PipelineConfig): 27 | def __init__( 28 | self, 29 | segmentation: m.SegmentationModel | None = None, 30 | duration: float = 5, 31 | step: float = 0.5, 32 | latency: float | Literal["max", "min"] | None = None, 33 | tau_active: float = 0.6, 34 | device: torch.device | None = None, 35 | sample_rate: int = 16000, 36 | **kwargs, 37 | ): 38 | # Default segmentation model is pyannote/segmentation 39 | self.segmentation = segmentation or m.SegmentationModel.from_pyannote( 40 | "pyannote/segmentation" 41 | ) 42 | 43 | self._duration = duration 44 | self._step = step 45 | self._sample_rate = sample_rate 46 | 47 | # Latency defaults to the step duration 48 | self._latency = latency 49 | if self._latency is None or self._latency == "min": 50 | self._latency = self._step 51 | elif self._latency == "max": 52 | self._latency = self._duration 53 | 54 | self.tau_active = tau_active 55 | self.device = device or torch.device( 56 | "cuda" if torch.cuda.is_available() else "cpu" 57 | ) 58 | 59 | @property 60 | def duration(self) -> float: 61 | return self._duration 62 | 63 | @property 64 | def step(self) -> float: 65 | return self._step 66 | 67 | @property 68 | def latency(self) -> float: 69 | return self._latency 70 | 71 | @property 72 | def sample_rate(self) -> int: 73 | return self._sample_rate 74 | 75 | 76 | class VoiceActivityDetection(base.Pipeline): 77 | def __init__(self, config: VoiceActivityDetectionConfig | None = None): 78 | self._config = VoiceActivityDetectionConfig() if config is None else config 79 | 80 | msg = f"Latency should be in the range [{self._config.step}, {self._config.duration}]" 81 | assert self._config.step <= self._config.latency <= self._config.duration, msg 82 | 83 | self.segmentation = SpeakerSegmentation( 84 | self._config.segmentation, self._config.device 85 | ) 86 | self.pred_aggregation = DelayedAggregation( 87 | 
self._config.step, 88 | self._config.latency, 89 | strategy="hamming", 90 | cropping_mode="loose", 91 | ) 92 | self.audio_aggregation = DelayedAggregation( 93 | self._config.step, 94 | self._config.latency, 95 | strategy="first", 96 | cropping_mode="center", 97 | ) 98 | self.binarize = Binarize(self._config.tau_active) 99 | 100 | # Internal state, handle with care 101 | self.timestamp_shift = 0 102 | self.chunk_buffer, self.pred_buffer = [], [] 103 | 104 | @staticmethod 105 | def get_config_class() -> type: 106 | return VoiceActivityDetectionConfig 107 | 108 | @staticmethod 109 | def suggest_metric() -> BaseMetric: 110 | return DetectionErrorRate(collar=0, skip_overlap=False) 111 | 112 | @staticmethod 113 | def hyper_parameters() -> Sequence[base.HyperParameter]: 114 | return [base.TauActive] 115 | 116 | @property 117 | def config(self) -> base.PipelineConfig: 118 | return self._config 119 | 120 | def reset(self): 121 | self.set_timestamp_shift(0) 122 | self.chunk_buffer, self.pred_buffer = [], [] 123 | 124 | def set_timestamp_shift(self, shift: float): 125 | self.timestamp_shift = shift 126 | 127 | def __call__( 128 | self, 129 | waveforms: Sequence[SlidingWindowFeature], 130 | ) -> Sequence[tuple[Annotation, SlidingWindowFeature]]: 131 | batch_size = len(waveforms) 132 | msg = "Pipeline expected at least 1 input" 133 | assert batch_size >= 1, msg 134 | 135 | # Create batch from chunk sequence, shape (batch, samples, channels) 136 | batch = torch.stack([torch.from_numpy(w.data) for w in waveforms]) 137 | 138 | expected_num_samples = int( 139 | np.rint(self.config.duration * self.config.sample_rate) 140 | ) 141 | msg = f"Expected {expected_num_samples} samples per chunk, but got {batch.shape[1]}" 142 | assert batch.shape[1] == expected_num_samples, msg 143 | 144 | # Extract segmentation 145 | segmentations = self.segmentation(batch) # shape (batch, frames, speakers) 146 | voice_detection = torch.max(segmentations, dim=-1, keepdim=True)[ 147 | 0 148 | ] # shape (batch, frames, 1) 149 | 150 | seg_resolution = waveforms[0].extent.duration / segmentations.shape[1] 151 | 152 | outputs = [] 153 | for wav, vad in zip(waveforms, voice_detection): 154 | # Add timestamps to segmentation 155 | sw = SlidingWindow( 156 | start=wav.extent.start, 157 | duration=seg_resolution, 158 | step=seg_resolution, 159 | ) 160 | vad = SlidingWindowFeature(vad.cpu().numpy(), sw) 161 | 162 | # Update sliding buffer 163 | self.chunk_buffer.append(wav) 164 | self.pred_buffer.append(vad) 165 | 166 | # Aggregate buffer outputs for this time step 167 | agg_waveform = self.audio_aggregation(self.chunk_buffer) 168 | agg_prediction = self.pred_aggregation(self.pred_buffer) 169 | agg_prediction = self.binarize(agg_prediction).get_timeline(copy=False) 170 | 171 | # Shift prediction timestamps if required 172 | if self.timestamp_shift != 0: 173 | shifted_agg_prediction = Timeline(uri=agg_prediction.uri) 174 | for segment in agg_prediction: 175 | new_segment = Segment( 176 | segment.start + self.timestamp_shift, 177 | segment.end + self.timestamp_shift, 178 | ) 179 | shifted_agg_prediction.add(new_segment) 180 | agg_prediction = shifted_agg_prediction 181 | 182 | # Convert timeline into annotation with single speaker "speech" 183 | agg_prediction = agg_prediction.to_annotation(utils.repeat_label("speech")) 184 | outputs.append((agg_prediction, agg_waveform)) 185 | 186 | # Make place for new chunks in buffer if required 187 | if len(self.chunk_buffer) == self.pred_aggregation.num_overlapping_windows: 188 | self.chunk_buffer = 
self.chunk_buffer[1:] 189 | self.pred_buffer = self.pred_buffer[1:] 190 | 191 | return outputs 192 | -------------------------------------------------------------------------------- /src/diart/console/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juanmc2005/diart/392d53a1b0cd67701ecc20b683bb10614df2f7fc/src/diart/console/__init__.py -------------------------------------------------------------------------------- /src/diart/console/benchmark.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import pandas as pd 5 | import torch 6 | 7 | from diart import argdoc 8 | from diart import models as m 9 | from diart import utils 10 | from diart.inference import Benchmark, Parallelize 11 | 12 | 13 | def run(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument( 16 | "root", 17 | type=Path, 18 | help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)", 19 | ) 20 | parser.add_argument( 21 | "--pipeline", 22 | default="SpeakerDiarization", 23 | type=str, 24 | help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'", 25 | ) 26 | parser.add_argument( 27 | "--segmentation", 28 | default="pyannote/segmentation", 29 | type=str, 30 | help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation", 31 | ) 32 | parser.add_argument( 33 | "--embedding", 34 | default="pyannote/embedding", 35 | type=str, 36 | help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding", 37 | ) 38 | parser.add_argument( 39 | "--reference", 40 | type=Path, 41 | help="Optional. Directory with RTTM files CONVERSATION.rttm. Names must match audio files", 42 | ) 43 | parser.add_argument( 44 | "--duration", 45 | type=float, 46 | default=5, 47 | help=f"{argdoc.DURATION}. Defaults to training segmentation duration", 48 | ) 49 | parser.add_argument( 50 | "--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5" 51 | ) 52 | parser.add_argument( 53 | "--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5" 54 | ) 55 | parser.add_argument( 56 | "--tau-active", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5" 57 | ) 58 | parser.add_argument( 59 | "--rho-update", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3" 60 | ) 61 | parser.add_argument( 62 | "--delta-new", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1" 63 | ) 64 | parser.add_argument( 65 | "--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3" 66 | ) 67 | parser.add_argument( 68 | "--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10" 69 | ) 70 | parser.add_argument( 71 | "--max-speakers", 72 | default=20, 73 | type=int, 74 | help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20", 75 | ) 76 | parser.add_argument( 77 | "--batch-size", 78 | default=32, 79 | type=int, 80 | help=f"{argdoc.BATCH_SIZE}. Defaults to 32", 81 | ) 82 | parser.add_argument( 83 | "--num-workers", 84 | default=0, 85 | type=int, 86 | help=f"{argdoc.NUM_WORKERS}. Defaults to 0 (no parallelism)", 87 | ) 88 | parser.add_argument( 89 | "--cpu", 90 | dest="cpu", 91 | action="store_true", 92 | help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise", 93 | ) 94 | parser.add_argument( 95 | "--output", type=Path, help=f"{argdoc.OUTPUT}. Defaults to no writing" 96 | ) 97 | parser.add_argument( 98 | "--hf-token", 99 | default="true", 100 | type=str, 101 | help=f"{argdoc.HF_TOKEN}. 
Defaults to 'true' (required by pyannote)", 102 | ) 103 | parser.add_argument( 104 | "--normalize-embedding-weights", 105 | action="store_true", 106 | help=f"{argdoc.NORMALIZE_EMBEDDING_WEIGHTS}. Defaults to False", 107 | ) 108 | args = parser.parse_args() 109 | 110 | # Resolve device 111 | args.device = torch.device("cpu") if args.cpu else None 112 | 113 | # Resolve models 114 | hf_token = utils.parse_hf_token_arg(args.hf_token) 115 | args.segmentation = m.SegmentationModel.from_pretrained(args.segmentation, hf_token) 116 | args.embedding = m.EmbeddingModel.from_pretrained(args.embedding, hf_token) 117 | 118 | pipeline_class = utils.get_pipeline_class(args.pipeline) 119 | 120 | benchmark = Benchmark( 121 | args.root, 122 | args.reference, 123 | args.output, 124 | show_progress=True, 125 | show_report=True, 126 | batch_size=args.batch_size, 127 | ) 128 | 129 | config = pipeline_class.get_config_class()(**vars(args)) 130 | if args.num_workers > 0: 131 | benchmark = Parallelize(benchmark, args.num_workers) 132 | 133 | report = benchmark(pipeline_class, config) 134 | 135 | if args.output is not None and isinstance(report, pd.DataFrame): 136 | report.to_csv(args.output / "benchmark_report.csv") 137 | 138 | 139 | if __name__ == "__main__": 140 | run() 141 | -------------------------------------------------------------------------------- /src/diart/console/client.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | from threading import Thread 4 | from typing import Text, Optional 5 | 6 | import rx.operators as ops 7 | from websocket import WebSocket 8 | 9 | from diart import argdoc 10 | from diart import sources as src 11 | from diart import utils 12 | 13 | 14 | def send_audio(ws: WebSocket, source: Text, step: float, sample_rate: int): 15 | # Create audio source 16 | source_components = source.split(":") 17 | if source_components[0] != "microphone": 18 | audio_source = src.FileAudioSource(source, sample_rate, block_duration=step) 19 | else: 20 | device = int(source_components[1]) if len(source_components) > 1 else None 21 | audio_source = src.MicrophoneAudioSource(step, device) 22 | 23 | # Encode audio, then send through websocket 24 | audio_source.stream.pipe(ops.map(utils.encode_audio)).subscribe_(ws.send) 25 | 26 | # Start reading audio 27 | audio_source.read() 28 | 29 | 30 | def receive_audio(ws: WebSocket, output: Optional[Path]): 31 | while True: 32 | message = ws.recv() 33 | print(f"Received: {message}", end="") 34 | if output is not None: 35 | with open(output, "a") as file: 36 | file.write(message) 37 | 38 | 39 | def run(): 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument( 42 | "source", 43 | type=str, 44 | help="Path to an audio file | 'microphone' | 'microphone:'", 45 | ) 46 | parser.add_argument("--host", required=True, type=str, help="Server host") 47 | parser.add_argument("--port", required=True, type=int, help="Server port") 48 | parser.add_argument( 49 | "--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5" 50 | ) 51 | parser.add_argument( 52 | "-sr", 53 | "--sample-rate", 54 | default=16000, 55 | type=int, 56 | help=f"{argdoc.SAMPLE_RATE}. Defaults to 16000", 57 | ) 58 | parser.add_argument( 59 | "-o", 60 | "--output-file", 61 | type=Path, 62 | help="Output RTTM file. 
Defaults to no writing", 63 | ) 64 | args = parser.parse_args() 65 | 66 | # Run websocket client 67 | ws = WebSocket() 68 | ws.connect(f"ws://{args.host}:{args.port}") 69 | sender = Thread( 70 | target=send_audio, args=[ws, args.source, args.step, args.sample_rate] 71 | ) 72 | receiver = Thread(target=receive_audio, args=[ws, args.output_file]) 73 | sender.start() 74 | receiver.start() 75 | 76 | 77 | if __name__ == "__main__": 78 | run() 79 | -------------------------------------------------------------------------------- /src/diart/console/serve.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import torch 5 | 6 | from diart import argdoc 7 | from diart import models as m 8 | from diart import sources as src 9 | from diart import utils 10 | from diart.inference import StreamingInference 11 | from diart.sinks import RTTMWriter 12 | 13 | 14 | def run(): 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--host", default="0.0.0.0", type=str, help="Server host") 17 | parser.add_argument("--port", default=7007, type=int, help="Server port") 18 | parser.add_argument( 19 | "--pipeline", 20 | default="SpeakerDiarization", 21 | type=str, 22 | help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'", 23 | ) 24 | parser.add_argument( 25 | "--segmentation", 26 | default="pyannote/segmentation", 27 | type=str, 28 | help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation", 29 | ) 30 | parser.add_argument( 31 | "--embedding", 32 | default="pyannote/embedding", 33 | type=str, 34 | help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding", 35 | ) 36 | parser.add_argument( 37 | "--duration", 38 | type=float, 39 | default=5, 40 | help=f"{argdoc.DURATION}. Defaults to training segmentation duration", 41 | ) 42 | parser.add_argument( 43 | "--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5" 44 | ) 45 | parser.add_argument( 46 | "--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5" 47 | ) 48 | parser.add_argument( 49 | "--tau-active", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5" 50 | ) 51 | parser.add_argument( 52 | "--rho-update", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3" 53 | ) 54 | parser.add_argument( 55 | "--delta-new", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1" 56 | ) 57 | parser.add_argument( 58 | "--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3" 59 | ) 60 | parser.add_argument( 61 | "--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10" 62 | ) 63 | parser.add_argument( 64 | "--max-speakers", 65 | default=20, 66 | type=int, 67 | help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20", 68 | ) 69 | parser.add_argument( 70 | "--cpu", 71 | dest="cpu", 72 | action="store_true", 73 | help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise", 74 | ) 75 | parser.add_argument( 76 | "--output", type=Path, help=f"{argdoc.OUTPUT}. Defaults to no writing" 77 | ) 78 | parser.add_argument( 79 | "--hf-token", 80 | default="true", 81 | type=str, 82 | help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)", 83 | ) 84 | parser.add_argument( 85 | "--normalize-embedding-weights", 86 | action="store_true", 87 | help=f"{argdoc.NORMALIZE_EMBEDDING_WEIGHTS}. 
Defaults to False", 88 | ) 89 | args = parser.parse_args() 90 | 91 | # Resolve device 92 | args.device = torch.device("cpu") if args.cpu else None 93 | 94 | # Resolve models 95 | hf_token = utils.parse_hf_token_arg(args.hf_token) 96 | args.segmentation = m.SegmentationModel.from_pretrained(args.segmentation, hf_token) 97 | args.embedding = m.EmbeddingModel.from_pretrained(args.embedding, hf_token) 98 | 99 | # Resolve pipeline 100 | pipeline_class = utils.get_pipeline_class(args.pipeline) 101 | config = pipeline_class.get_config_class()(**vars(args)) 102 | pipeline = pipeline_class(config) 103 | 104 | # Create websocket audio source 105 | audio_source = src.WebSocketAudioSource(config.sample_rate, args.host, args.port) 106 | 107 | # Run online inference 108 | inference = StreamingInference( 109 | pipeline, 110 | audio_source, 111 | batch_size=1, 112 | do_profile=False, 113 | do_plot=False, 114 | show_progress=True, 115 | ) 116 | 117 | # Write to disk if required 118 | if args.output is not None: 119 | inference.attach_observers( 120 | RTTMWriter(audio_source.uri, args.output / f"{audio_source.uri}.rttm") 121 | ) 122 | 123 | # Send back responses as RTTM text lines 124 | inference.attach_hooks(lambda ann_wav: audio_source.send(ann_wav[0].to_rttm())) 125 | 126 | # Run server and pipeline 127 | inference() 128 | 129 | 130 | if __name__ == "__main__": 131 | run() 132 | -------------------------------------------------------------------------------- /src/diart/console/stream.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import torch 5 | 6 | from diart import argdoc 7 | from diart import models as m 8 | from diart import sources as src 9 | from diart import utils 10 | from diart.inference import StreamingInference 11 | from diart.sinks import RTTMWriter 12 | 13 | 14 | def run(): 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument( 17 | "source", 18 | type=str, 19 | help="Path to an audio file | 'microphone' | 'microphone:'", 20 | ) 21 | parser.add_argument( 22 | "--pipeline", 23 | default="SpeakerDiarization", 24 | type=str, 25 | help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'", 26 | ) 27 | parser.add_argument( 28 | "--segmentation", 29 | default="pyannote/segmentation", 30 | type=str, 31 | help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation", 32 | ) 33 | parser.add_argument( 34 | "--embedding", 35 | default="pyannote/embedding", 36 | type=str, 37 | help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding", 38 | ) 39 | parser.add_argument( 40 | "--duration", 41 | type=float, 42 | default=5, 43 | help=f"{argdoc.DURATION}. Defaults to training segmentation duration", 44 | ) 45 | parser.add_argument( 46 | "--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5" 47 | ) 48 | parser.add_argument( 49 | "--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5" 50 | ) 51 | parser.add_argument( 52 | "--tau-active", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5" 53 | ) 54 | parser.add_argument( 55 | "--rho-update", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3" 56 | ) 57 | parser.add_argument( 58 | "--delta-new", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1" 59 | ) 60 | parser.add_argument( 61 | "--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3" 62 | ) 63 | parser.add_argument( 64 | "--beta", default=10, type=float, help=f"{argdoc.BETA}. 
Defaults to 10" 65 | ) 66 | parser.add_argument( 67 | "--max-speakers", 68 | default=20, 69 | type=int, 70 | help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20", 71 | ) 72 | parser.add_argument( 73 | "--no-plot", 74 | dest="no_plot", 75 | action="store_true", 76 | help="Skip plotting for faster inference", 77 | ) 78 | parser.add_argument( 79 | "--cpu", 80 | dest="cpu", 81 | action="store_true", 82 | help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise", 83 | ) 84 | parser.add_argument( 85 | "--output", 86 | type=str, 87 | help=f"{argdoc.OUTPUT}. Defaults to home directory if SOURCE == 'microphone' or parent directory if SOURCE is a file", 88 | ) 89 | parser.add_argument( 90 | "--hf-token", 91 | default="true", 92 | type=str, 93 | help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)", 94 | ) 95 | parser.add_argument( 96 | "--normalize-embedding-weights", 97 | action="store_true", 98 | help=f"{argdoc.NORMALIZE_EMBEDDING_WEIGHTS}. Defaults to False", 99 | ) 100 | args = parser.parse_args() 101 | 102 | # Resolve device 103 | args.device = torch.device("cpu") if args.cpu else None 104 | 105 | # Resolve models 106 | hf_token = utils.parse_hf_token_arg(args.hf_token) 107 | args.segmentation = m.SegmentationModel.from_pretrained(args.segmentation, hf_token) 108 | args.embedding = m.EmbeddingModel.from_pretrained(args.embedding, hf_token) 109 | 110 | # Resolve pipeline 111 | pipeline_class = utils.get_pipeline_class(args.pipeline) 112 | config = pipeline_class.get_config_class()(**vars(args)) 113 | pipeline = pipeline_class(config) 114 | 115 | # Manage audio source 116 | source_components = args.source.split(":") 117 | if source_components[0] != "microphone": 118 | args.source = Path(args.source).expanduser() 119 | args.output = args.source.parent if args.output is None else Path(args.output) 120 | padding = config.get_file_padding(args.source) 121 | audio_source = src.FileAudioSource( 122 | args.source, config.sample_rate, padding, config.step 123 | ) 124 | pipeline.set_timestamp_shift(-padding[0]) 125 | else: 126 | args.output = ( 127 | Path("~/").expanduser() if args.output is None else Path(args.output) 128 | ) 129 | device = int(source_components[1]) if len(source_components) > 1 else None 130 | audio_source = src.MicrophoneAudioSource(config.step, device) 131 | 132 | # Run online inference 133 | inference = StreamingInference( 134 | pipeline, 135 | audio_source, 136 | batch_size=1, 137 | do_profile=True, 138 | do_plot=not args.no_plot, 139 | show_progress=True, 140 | ) 141 | inference.attach_observers( 142 | RTTMWriter(audio_source.uri, args.output / f"{audio_source.uri}.rttm") 143 | ) 144 | try: 145 | inference() 146 | except KeyboardInterrupt: 147 | pass 148 | 149 | 150 | if __name__ == "__main__": 151 | run() 152 | -------------------------------------------------------------------------------- /src/diart/console/tune.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import optuna 5 | import torch 6 | from optuna.samplers import TPESampler 7 | 8 | from diart import argdoc 9 | from diart import models as m 10 | from diart import utils 11 | from diart.blocks.base import HyperParameter 12 | from diart.optim import Optimizer 13 | 14 | 15 | def run(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument( 18 | "root", 19 | type=str, 20 | help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)", 21 | ) 22 | parser.add_argument( 23 | "--reference", 24 | 
required=True, 25 | type=str, 26 | help="Directory with RTTM files CONVERSATION.rttm. Names must match audio files", 27 | ) 28 | parser.add_argument( 29 | "--pipeline", 30 | default="SpeakerDiarization", 31 | type=str, 32 | help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'", 33 | ) 34 | parser.add_argument( 35 | "--segmentation", 36 | default="pyannote/segmentation", 37 | type=str, 38 | help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation", 39 | ) 40 | parser.add_argument( 41 | "--embedding", 42 | default="pyannote/embedding", 43 | type=str, 44 | help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding", 45 | ) 46 | parser.add_argument( 47 | "--duration", 48 | type=float, 49 | default=5, 50 | help=f"{argdoc.DURATION}. Defaults to training segmentation duration", 51 | ) 52 | parser.add_argument( 53 | "--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5" 54 | ) 55 | parser.add_argument( 56 | "--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5" 57 | ) 58 | parser.add_argument( 59 | "--tau-active", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5" 60 | ) 61 | parser.add_argument( 62 | "--rho-update", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3" 63 | ) 64 | parser.add_argument( 65 | "--delta-new", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1" 66 | ) 67 | parser.add_argument( 68 | "--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3" 69 | ) 70 | parser.add_argument( 71 | "--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10" 72 | ) 73 | parser.add_argument( 74 | "--max-speakers", 75 | default=20, 76 | type=int, 77 | help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20", 78 | ) 79 | parser.add_argument( 80 | "--batch-size", 81 | default=32, 82 | type=int, 83 | help=f"{argdoc.BATCH_SIZE}. Defaults to 32", 84 | ) 85 | parser.add_argument( 86 | "--cpu", 87 | dest="cpu", 88 | action="store_true", 89 | help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise", 90 | ) 91 | parser.add_argument( 92 | "--hparams", 93 | nargs="+", 94 | default=("tau_active", "rho_update", "delta_new"), 95 | help="Hyper-parameters to optimize. Must match names in `PipelineConfig`. Defaults to tau_active, rho_update and delta_new", 96 | ) 97 | parser.add_argument( 98 | "--num-iter", default=100, type=int, help="Number of optimization trials" 99 | ) 100 | parser.add_argument( 101 | "--storage", 102 | type=str, 103 | help="Optuna storage string. If provided, continue a previous study instead of creating one. The database name must match the study name", 104 | ) 105 | parser.add_argument("--output", type=str, help="Working directory") 106 | parser.add_argument( 107 | "--hf-token", 108 | default="true", 109 | type=str, 110 | help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)", 111 | ) 112 | parser.add_argument( 113 | "--normalize-embedding-weights", 114 | action="store_true", 115 | help=f"{argdoc.NORMALIZE_EMBEDDING_WEIGHTS}. 
Defaults to False", 116 | ) 117 | args = parser.parse_args() 118 | 119 | # Resolve device 120 | args.device = torch.device("cpu") if args.cpu else None 121 | 122 | # Resolve models 123 | hf_token = utils.parse_hf_token_arg(args.hf_token) 124 | args.segmentation = m.SegmentationModel.from_pretrained(args.segmentation, hf_token) 125 | args.embedding = m.EmbeddingModel.from_pretrained(args.embedding, hf_token) 126 | 127 | # Retrieve pipeline class 128 | pipeline_class = utils.get_pipeline_class(args.pipeline) 129 | 130 | # Create the base configuration for each trial 131 | base_config = pipeline_class.get_config_class()(**vars(args)) 132 | 133 | # Create hyper-parameters to optimize 134 | possible_hparams = pipeline_class.hyper_parameters() 135 | hparams = [HyperParameter.from_name(name) for name in args.hparams] 136 | hparams = [hp for hp in hparams if hp in possible_hparams] 137 | if not hparams: 138 | print( 139 | f"No hyper-parameters to optimize. " 140 | f"Make sure to select one of: {', '.join([hp.name for hp in possible_hparams])}" 141 | ) 142 | exit(1) 143 | 144 | # Use a custom storage if given 145 | if args.output is not None: 146 | msg = "Both `output` and `storage` were set, but only one was expected" 147 | assert args.storage is None, msg 148 | args.output = Path(args.output).expanduser() 149 | args.output.mkdir(parents=True, exist_ok=True) 150 | study_or_path = args.output 151 | elif args.storage is not None: 152 | db_name = Path(args.storage).stem 153 | study_or_path = optuna.load_study(db_name, args.storage, TPESampler()) 154 | else: 155 | msg = "Please provide either `output` or `storage`" 156 | raise ValueError(msg) 157 | 158 | # Run optimization 159 | Optimizer( 160 | pipeline_class=pipeline_class, 161 | speech_path=args.root, 162 | reference_path=args.reference, 163 | study_or_path=study_or_path, 164 | batch_size=args.batch_size, 165 | hparams=hparams, 166 | base_config=base_config, 167 | )(num_iter=args.num_iter, show_progress=True) 168 | 169 | 170 | if __name__ == "__main__": 171 | run() 172 | -------------------------------------------------------------------------------- /src/diart/features.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Optional 2 | from abc import ABC, abstractmethod 3 | 4 | import numpy as np 5 | import torch 6 | from pyannote.core import SlidingWindow, SlidingWindowFeature 7 | 8 | TemporalFeatures = Union[SlidingWindowFeature, np.ndarray, torch.Tensor] 9 | 10 | 11 | class TemporalFeatureFormatterState(ABC): 12 | """ 13 | Represents the recorded type of a temporal feature formatter. 14 | Its job is to transform temporal features into tensors and 15 | recover the original format on other features. 16 | """ 17 | 18 | @abstractmethod 19 | def to_tensor(self, features: TemporalFeatures) -> torch.Tensor: 20 | pass 21 | 22 | @abstractmethod 23 | def to_internal_type(self, features: torch.Tensor) -> TemporalFeatures: 24 | """ 25 | Cast `features` to the representing type and remove batch dimension if required. 26 | 27 | Parameters 28 | ---------- 29 | features: torch.Tensor, shape (batch, frames, dim) 30 | Batched temporal features. 
31 | Returns 32 | ------- 33 | new_features: SlidingWindowFeature or numpy.ndarray or torch.Tensor, shape (batch, frames, dim) 34 | """ 35 | pass 36 | 37 | 38 | class SlidingWindowFeatureFormatterState(TemporalFeatureFormatterState): 39 | def __init__(self, duration: float): 40 | self.duration = duration 41 | self._cur_start_time = 0 42 | 43 | def to_tensor(self, features: SlidingWindowFeature) -> torch.Tensor: 44 | msg = "Features sliding window duration and step must be equal" 45 | assert features.sliding_window.duration == features.sliding_window.step, msg 46 | self._cur_start_time = features.sliding_window.start 47 | return torch.from_numpy(features.data) 48 | 49 | def to_internal_type(self, features: torch.Tensor) -> TemporalFeatures: 50 | batch_size, num_frames, _ = features.shape 51 | assert batch_size == 1, "Batched SlidingWindowFeature objects are not supported" 52 | # Calculate resolution 53 | resolution = self.duration / num_frames 54 | # Temporal shift to keep track of current start time 55 | resolution = SlidingWindow( 56 | start=self._cur_start_time, duration=resolution, step=resolution 57 | ) 58 | return SlidingWindowFeature(features.squeeze(dim=0).cpu().numpy(), resolution) 59 | 60 | 61 | class NumpyArrayFormatterState(TemporalFeatureFormatterState): 62 | def to_tensor(self, features: np.ndarray) -> torch.Tensor: 63 | return torch.from_numpy(features) 64 | 65 | def to_internal_type(self, features: torch.Tensor) -> TemporalFeatures: 66 | return features.cpu().numpy() 67 | 68 | 69 | class PytorchTensorFormatterState(TemporalFeatureFormatterState): 70 | def to_tensor(self, features: torch.Tensor) -> torch.Tensor: 71 | return features 72 | 73 | def to_internal_type(self, features: torch.Tensor) -> TemporalFeatures: 74 | return features 75 | 76 | 77 | class TemporalFeatureFormatter: 78 | """ 79 | Manages the typing and format of temporal features. 80 | When casting temporal features as torch.Tensor, it remembers its 81 | type and format so it can lately restore it on other temporal features. 82 | """ 83 | 84 | def __init__(self): 85 | self.state: Optional[TemporalFeatureFormatterState] = None 86 | 87 | def set_state(self, features: TemporalFeatures): 88 | if isinstance(features, SlidingWindowFeature): 89 | msg = "Features sliding window duration and step must be equal" 90 | assert features.sliding_window.duration == features.sliding_window.step, msg 91 | self.state = SlidingWindowFeatureFormatterState( 92 | features.data.shape[0] * features.sliding_window.duration, 93 | ) 94 | elif isinstance(features, np.ndarray): 95 | self.state = NumpyArrayFormatterState() 96 | elif isinstance(features, torch.Tensor): 97 | self.state = PytorchTensorFormatterState() 98 | else: 99 | msg = "Unknown format. Provide one of SlidingWindowFeature, numpy.ndarray, torch.Tensor" 100 | raise ValueError(msg) 101 | 102 | def cast(self, features: TemporalFeatures) -> torch.Tensor: 103 | """ 104 | Transform features into a `torch.Tensor` and add batch dimension if missing. 
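        A minimal round-trip sketch (the array shape is chosen only for
        illustration):

        >>> formatter = TemporalFeatureFormatter()
        >>> batched = formatter.cast(np.zeros((100, 4)))  # torch.Tensor, shape (1, 100, 4)
        >>> restored = formatter.restore_type(batched)    # back to numpy.ndarray, shape (1, 100, 4)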
105 | 106 | Parameters 107 | ---------- 108 | features: SlidingWindowFeature or numpy.ndarray or torch.Tensor 109 | Shape (frames, dim) or (batch, frames, dim) 110 | 111 | Returns 112 | ------- 113 | features: torch.Tensor, shape (batch, frames, dim) 114 | """ 115 | # Set state if not initialized 116 | self.set_state(features) 117 | # Convert features to tensor 118 | data = self.state.to_tensor(features) 119 | # Make sure there's a batch dimension 120 | msg = "Temporal features must be 2D or 3D" 121 | assert data.ndim in (2, 3), msg 122 | if data.ndim == 2: 123 | data = data.unsqueeze(0) 124 | return data.float() 125 | 126 | def restore_type(self, features: torch.Tensor) -> TemporalFeatures: 127 | """ 128 | Cast `features` to the internal type and remove batch dimension if required. 129 | 130 | Parameters 131 | ---------- 132 | features: torch.Tensor, shape (batch, frames, dim) 133 | Batched temporal features. 134 | Returns 135 | ------- 136 | new_features: SlidingWindowFeature or numpy.ndarray or torch.Tensor, shape (batch, frames, dim) 137 | """ 138 | return self.state.to_internal_type(features) 139 | -------------------------------------------------------------------------------- /src/diart/functional.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import torch 4 | 5 | 6 | def overlapped_speech_penalty( 7 | segmentation: torch.Tensor, gamma: float = 3, beta: float = 10 8 | ): 9 | # segmentation has shape (batch, frames, speakers) 10 | probs = torch.softmax(beta * segmentation, dim=-1) 11 | weights = torch.pow(segmentation, gamma) * torch.pow(probs, gamma) 12 | weights[weights < 1e-8] = 1e-8 13 | return weights 14 | 15 | 16 | def normalize_embeddings( 17 | embeddings: torch.Tensor, norm: float | torch.Tensor = 1 18 | ) -> torch.Tensor: 19 | # embeddings has shape (batch, speakers, feat) or (speakers, feat) 20 | if embeddings.ndim == 2: 21 | embeddings = embeddings.unsqueeze(0) 22 | if isinstance(norm, torch.Tensor): 23 | batch_size1, num_speakers1, _ = norm.shape 24 | batch_size2, num_speakers2, _ = embeddings.shape 25 | assert batch_size1 == batch_size2 and num_speakers1 == num_speakers2 26 | emb_norm = torch.norm(embeddings, p=2, dim=-1, keepdim=True) 27 | return norm * embeddings / emb_norm 28 | -------------------------------------------------------------------------------- /src/diart/mapping.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Callable, Iterable, List, Optional, Text, Tuple, Union, Dict 4 | from abc import ABC, abstractmethod 5 | 6 | import numpy as np 7 | from pyannote.core.utils.distance import cdist 8 | from scipy.optimize import linear_sum_assignment as lsap 9 | 10 | 11 | class MappingMatrixObjective(ABC): 12 | def invalid_tensor(self, shape: Union[Tuple, int]) -> np.ndarray: 13 | return np.ones(shape) * self.invalid_value 14 | 15 | def optimal_assignments(self, matrix: np.ndarray) -> List[int]: 16 | return list(lsap(matrix, self.maximize)[1]) 17 | 18 | def mapped_indices(self, matrix: np.ndarray, axis: int) -> List[int]: 19 | # Entries full of invalid_value are not mapped 20 | best_values = self.best_value_fn(matrix, axis=axis) 21 | return list(np.where(best_values != self.invalid_value)[0]) 22 | 23 | def hard_speaker_map( 24 | self, num_src: int, num_tgt: int, assignments: Iterable[Tuple[int, int]] 25 | ) -> SpeakerMap: 26 | """Create a hard map object where the highest cost is put 
27 | everywhere except on hard assignments from ``assignments``. 28 | 29 | Parameters 30 | ---------- 31 | num_src: int 32 | Number of source speakers 33 | num_tgt: int 34 | Number of target speakers 35 | assignments: Iterable[Tuple[int, int]] 36 | An iterable of tuples with two elements having the first element as the source speaker 37 | and the second element as the target speaker 38 | 39 | Returns 40 | ------- 41 | SpeakerMap 42 | """ 43 | mapping_matrix = self.invalid_tensor(shape=(num_src, num_tgt)) 44 | for src, tgt in assignments: 45 | mapping_matrix[src, tgt] = self.best_possible_value 46 | return SpeakerMap(mapping_matrix, self) 47 | 48 | @property 49 | def invalid_value(self) -> float: 50 | # linear_sum_assignment cannot deal with np.inf, 51 | # which would be ideal. Using a big number instead. 52 | return -1e10 if self.maximize else 1e10 53 | 54 | @property 55 | @abstractmethod 56 | def maximize(self) -> bool: 57 | pass 58 | 59 | @property 60 | @abstractmethod 61 | def best_possible_value(self) -> float: 62 | pass 63 | 64 | @property 65 | @abstractmethod 66 | def best_value_fn(self) -> Callable: 67 | pass 68 | 69 | 70 | class MinimizationObjective(MappingMatrixObjective): 71 | @property 72 | def maximize(self) -> bool: 73 | return False 74 | 75 | @property 76 | def best_possible_value(self) -> float: 77 | return 0 78 | 79 | @property 80 | def best_value_fn(self) -> Callable: 81 | return np.min 82 | 83 | 84 | class MaximizationObjective(MappingMatrixObjective): 85 | def __init__(self, max_value: float = 1): 86 | self.max_value = max_value 87 | 88 | @property 89 | def maximize(self) -> bool: 90 | return True 91 | 92 | @property 93 | def best_possible_value(self) -> float: 94 | return self.max_value 95 | 96 | @property 97 | def best_value_fn(self) -> Callable: 98 | return np.max 99 | 100 | 101 | class SpeakerMapBuilder: 102 | @staticmethod 103 | def hard_map( 104 | shape: Tuple[int, int], assignments: Iterable[Tuple[int, int]], maximize: bool 105 | ) -> SpeakerMap: 106 | """Create a ``SpeakerMap`` object based on the given assignments. This is a "hard" map, meaning that the 107 | highest cost is put everywhere except on hard assignments from ``assignments``. 
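        For example (shape and assignments chosen only for illustration), mapping
        source speaker 0 to target 1 and source speaker 1 to target 2:

        >>> speaker_map = SpeakerMapBuilder.hard_map(
        ...     shape=(2, 3), assignments=[(0, 1), (1, 2)], maximize=False
        ... )
        >>> speaker_map.valid_assignments()
        ([0, 1], [1, 2])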
108 | 109 | Parameters 110 | ---------- 111 | shape: Tuple[int, int]) 112 | Shape of the mapping matrix 113 | assignments: Iterable[Tuple[int, int]] 114 | An iterable of tuples with two elements having the first element as the source speaker 115 | and the second element as the target speaker 116 | maximize: bool 117 | whether to use scores where higher is better (true) or where lower is better (false) 118 | 119 | Returns 120 | ------- 121 | SpeakerMap 122 | """ 123 | num_src, num_tgt = shape 124 | objective = MaximizationObjective if maximize else MinimizationObjective 125 | return objective().hard_speaker_map(num_src, num_tgt, assignments) 126 | 127 | @staticmethod 128 | def correlation(scores1: np.ndarray, scores2: np.ndarray) -> SpeakerMap: 129 | score_matrix_per_frame = ( 130 | np.stack( # (local_speakers, num_frames, global_speakers) 131 | [ 132 | scores1[:, speaker : speaker + 1] * scores2 133 | for speaker in range(scores1.shape[1]) 134 | ], 135 | axis=0, 136 | ) 137 | ) 138 | # Calculate total speech "activations" per local speaker 139 | local_speech_scores = np.sum(scores1, axis=0).reshape(-1, 1) 140 | # Calculate speaker mapping matrix 141 | # Cost matrix is the correlation divided by sum of local activations 142 | score_matrix = np.sum(score_matrix_per_frame, axis=1) / local_speech_scores 143 | # We want to maximize the correlation to calculate optimal speaker alignments 144 | return SpeakerMap(score_matrix, MaximizationObjective(max_value=1)) 145 | 146 | @staticmethod 147 | def mse(scores1: np.ndarray, scores2: np.ndarray) -> SpeakerMap: 148 | cost_matrix = np.stack( # (local_speakers, local_speakers) 149 | [ 150 | np.square(scores1[:, speaker : speaker + 1] - scores2).mean(axis=0) 151 | for speaker in range(scores1.shape[1]) 152 | ], 153 | axis=0, 154 | ) 155 | # We want to minimize the MSE to calculate optimal speaker alignments 156 | return SpeakerMap(cost_matrix, MinimizationObjective()) 157 | 158 | @staticmethod 159 | def mae(scores1: np.ndarray, scores2: np.ndarray) -> SpeakerMap: 160 | cost_matrix = np.stack( # (local_speakers, local_speakers) 161 | [ 162 | np.abs(scores1[:, speaker : speaker + 1] - scores2).mean(axis=0) 163 | for speaker in range(scores1.shape[1]) 164 | ], 165 | axis=0, 166 | ) 167 | # We want to minimize the MSE to calculate optimal speaker alignments 168 | return SpeakerMap(cost_matrix, MinimizationObjective()) 169 | 170 | @staticmethod 171 | def dist( 172 | embeddings1: np.ndarray, embeddings2: np.ndarray, metric: Text = "cosine" 173 | ) -> SpeakerMap: 174 | # We want to minimize the distance to calculate optimal speaker alignments 175 | dist_matrix = cdist(embeddings1, embeddings2, metric=metric) 176 | return SpeakerMap(dist_matrix, MinimizationObjective()) 177 | 178 | 179 | class SpeakerMap: 180 | def __init__(self, mapping_matrix: np.ndarray, objective: MappingMatrixObjective): 181 | self.mapping_matrix = mapping_matrix 182 | self.objective = objective 183 | self.num_source_speakers = self.mapping_matrix.shape[0] 184 | self.num_target_speakers = self.mapping_matrix.shape[1] 185 | self.mapped_source_speakers = self.objective.mapped_indices( 186 | self.mapping_matrix, axis=1 187 | ) 188 | self.mapped_target_speakers = self.objective.mapped_indices( 189 | self.mapping_matrix, axis=0 190 | ) 191 | self._opt_assignments: Optional[List[int]] = None 192 | 193 | @property 194 | def _raw_optimal_assignments(self) -> List[int]: 195 | if self._opt_assignments is None: 196 | self._opt_assignments = self.objective.optimal_assignments( 197 | self.mapping_matrix 
198 | ) 199 | return self._opt_assignments 200 | 201 | @property 202 | def shape(self) -> Tuple[int, int]: 203 | return self.mapping_matrix.shape 204 | 205 | def __len__(self): 206 | return len(self.mapped_source_speakers) 207 | 208 | def __add__(self, other: SpeakerMap) -> SpeakerMap: 209 | return self.union(other) 210 | 211 | def _strict_check_valid(self, src: int, tgt: int) -> bool: 212 | return self.mapping_matrix[src, tgt] != self.objective.invalid_value 213 | 214 | def _loose_check_valid(self, src: int, tgt: int) -> bool: 215 | return self.is_source_speaker_mapped(src) 216 | 217 | def valid_assignments( 218 | self, 219 | strict: bool = False, 220 | as_array: bool = False, 221 | ) -> Union[Tuple[List[int], List[int]], Tuple[np.ndarray, np.ndarray]]: 222 | source, target = [], [] 223 | val_type = "strict" if strict else "loose" 224 | is_valid = getattr(self, f"_{val_type}_check_valid") 225 | for src, tgt in enumerate(self._raw_optimal_assignments): 226 | if is_valid(src, tgt): 227 | source.append(src) 228 | target.append(tgt) 229 | if as_array: 230 | source, target = np.array(source), np.array(target) 231 | return source, target 232 | 233 | def to_dict(self, strict: bool = False) -> Dict[int, int]: 234 | return {src: tgt for src, tgt in zip(*self.valid_assignments(strict))} 235 | 236 | def to_inverse_dict(self, strict: bool = False) -> Dict[int, int]: 237 | return {tgt: src for src, tgt in zip(*self.valid_assignments(strict))} 238 | 239 | def is_source_speaker_mapped(self, source_speaker: int) -> bool: 240 | return source_speaker in self.mapped_source_speakers 241 | 242 | def is_target_speaker_mapped(self, target_speaker: int) -> bool: 243 | return target_speaker in self.mapped_target_speakers 244 | 245 | def set_source_speaker(self, src_speaker, tgt_speaker: int): 246 | # if not force: 247 | # assert not self.is_source_speaker_mapped(src_speaker) 248 | # assert not self.is_target_speaker_mapped(tgt_speaker) 249 | new_cost_matrix = np.copy(self.mapping_matrix) 250 | new_cost_matrix[src_speaker, tgt_speaker] = self.objective.best_possible_value 251 | return SpeakerMap(new_cost_matrix, self.objective) 252 | 253 | def unmap_source_speaker(self, src_speaker: int): 254 | new_cost_matrix = np.copy(self.mapping_matrix) 255 | new_cost_matrix[src_speaker] = self.objective.invalid_tensor( 256 | shape=self.num_target_speakers 257 | ) 258 | return SpeakerMap(new_cost_matrix, self.objective) 259 | 260 | def unmap_threshold(self, threshold: float) -> SpeakerMap: 261 | def is_invalid(val): 262 | if self.objective.maximize: 263 | return val <= threshold 264 | else: 265 | return val >= threshold 266 | 267 | return self.unmap_speakers( 268 | [ 269 | src 270 | for src, tgt in zip(*self.valid_assignments()) 271 | if is_invalid(self.mapping_matrix[src, tgt]) 272 | ] 273 | ) 274 | 275 | def unmap_speakers( 276 | self, 277 | source_speakers: Optional[Union[List[int], np.ndarray]] = None, 278 | target_speakers: Optional[Union[List[int], np.ndarray]] = None, 279 | ) -> SpeakerMap: 280 | # Set invalid_value to disabled speakers. 281 | # If they happen to be the best mapping for a local speaker, 282 | # it means that the mapping of the local speaker should be ignored. 
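        # For example, `unmap_speakers(source_speakers=[0])` returns a new map in
        # which source speaker 0 no longer appears in `valid_assignments()`.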
283 | source_speakers = [] if source_speakers is None else source_speakers 284 | target_speakers = [] if target_speakers is None else target_speakers 285 | new_cost_matrix = np.copy(self.mapping_matrix) 286 | for speaker1 in source_speakers: 287 | new_cost_matrix[speaker1] = self.objective.invalid_tensor( 288 | shape=self.num_target_speakers 289 | ) 290 | for speaker2 in target_speakers: 291 | new_cost_matrix[:, speaker2] = self.objective.invalid_tensor( 292 | shape=self.num_source_speakers 293 | ) 294 | return SpeakerMap(new_cost_matrix, self.objective) 295 | 296 | def compose(self, other: SpeakerMap) -> SpeakerMap: 297 | """Let's say that `self` is a mapping of `source_speakers` to `intermediate_speakers` 298 | and `other` is a mapping from `intermediate_speakers` to `target_speakers`. 299 | 300 | Compose `self` with `other` to obtain a new mapping from `source_speakers` to `target_speakers`. 301 | """ 302 | new_cost_matrix = other.objective.invalid_tensor( 303 | shape=(self.num_source_speakers, other.num_target_speakers) 304 | ) 305 | for src_speaker, intermediate_speaker in zip(*self.valid_assignments()): 306 | target_speaker = other.mapping_matrix[intermediate_speaker] 307 | new_cost_matrix[src_speaker] = target_speaker 308 | return SpeakerMap(new_cost_matrix, other.objective) 309 | 310 | def union(self, other: SpeakerMap): 311 | """`self` and `other` are two maps with the same dimensions. 312 | Return a new hard speaker map containing assignments in both maps. 313 | 314 | An assignment from `other` is ignored if it is in conflict with 315 | a source or target speaker from `self`. 316 | 317 | WARNING: The resulting map doesn't preserve soft assignments 318 | because `self` and `other` might have different objectives. 319 | 320 | :param other: SpeakerMap 321 | Another speaker map 322 | """ 323 | assert self.shape == other.shape 324 | best_val = self.objective.best_possible_value 325 | new_cost_matrix = self.objective.invalid_tensor(self.shape) 326 | self_src, self_tgt = self.valid_assignments() 327 | other_src, other_tgt = other.valid_assignments() 328 | for src in range(self.num_source_speakers): 329 | if src in self_src: 330 | # `self` is preserved by default 331 | tgt = self_tgt[self_src.index(src)] 332 | new_cost_matrix[src, tgt] = best_val 333 | elif src in other_src: 334 | # In order to add an assignment from `other`, 335 | # the target speaker cannot be in conflict with `self` 336 | tgt = other_tgt[other_src.index(src)] 337 | if not self.is_target_speaker_mapped(tgt): 338 | new_cost_matrix[src, tgt] = best_val 339 | return SpeakerMap(new_cost_matrix, self.objective) 340 | 341 | def apply(self, source_scores: np.ndarray) -> np.ndarray: 342 | """Apply this mapping to a score matrix of source speakers 343 | to obtain the same scores aligned to target speakers. 344 | 345 | Parameters 346 | ---------- 347 | source_scores : SlidingWindowFeature, (num_frames, num_source_speakers) 348 | Source speaker scores per frame. 349 | 350 | Returns 351 | ------- 352 | projected_scores : SlidingWindowFeature, (num_frames, num_target_speakers) 353 | Score matrix for target speakers. 354 | """ 355 | # Map local speaker scores to the most probable global speaker. 
Unknown scores are set to 0 356 | num_frames = source_scores.data.shape[0] 357 | projected_scores = np.zeros((num_frames, self.num_target_speakers)) 358 | for src_speaker, tgt_speaker in zip(*self.valid_assignments()): 359 | projected_scores[:, tgt_speaker] = source_scores[:, src_speaker] 360 | return projected_scores 361 | -------------------------------------------------------------------------------- /src/diart/models.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from abc import ABC 4 | from pathlib import Path 5 | from typing import Optional, Text, Union, Callable, List 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | from requests import HTTPError 11 | 12 | try: 13 | from pyannote.audio import Model 14 | from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding 15 | from pyannote.audio.utils.powerset import Powerset 16 | 17 | IS_PYANNOTE_AVAILABLE = True 18 | except ImportError: 19 | IS_PYANNOTE_AVAILABLE = False 20 | 21 | try: 22 | import onnxruntime as ort 23 | 24 | IS_ONNX_AVAILABLE = True 25 | except ImportError: 26 | IS_ONNX_AVAILABLE = False 27 | 28 | 29 | class PowersetAdapter(nn.Module): 30 | def __init__(self, segmentation_model: nn.Module): 31 | super().__init__() 32 | self.model = segmentation_model 33 | specs = self.model.specifications 34 | max_speakers_per_frame = specs.powerset_max_classes 35 | max_speakers_per_chunk = len(specs.classes) 36 | self.powerset = Powerset(max_speakers_per_chunk, max_speakers_per_frame) 37 | 38 | def forward(self, waveform: torch.Tensor) -> torch.Tensor: 39 | return self.powerset.to_multilabel(self.model(waveform)) 40 | 41 | 42 | class PyannoteLoader: 43 | def __init__(self, model_info, hf_token: Union[Text, bool, None] = True): 44 | super().__init__() 45 | self.model_info = model_info 46 | self.hf_token = hf_token 47 | 48 | def __call__(self) -> Callable: 49 | try: 50 | model = Model.from_pretrained(self.model_info, use_auth_token=self.hf_token) 51 | specs = getattr(model, "specifications", None) 52 | if specs is not None and specs.powerset: 53 | model = PowersetAdapter(model) 54 | return model 55 | except HTTPError: 56 | pass 57 | except ModuleNotFoundError: 58 | pass 59 | return PretrainedSpeakerEmbedding(self.model_info, use_auth_token=self.hf_token) 60 | 61 | 62 | class ONNXLoader: 63 | def __init__(self, path: str | Path, input_names: List[str], output_name: str): 64 | super().__init__() 65 | self.path = Path(path) 66 | self.input_names = input_names 67 | self.output_name = output_name 68 | 69 | def __call__(self) -> ONNXModel: 70 | return ONNXModel(self.path, self.input_names, self.output_name) 71 | 72 | 73 | class ONNXModel: 74 | def __init__(self, path: Path, input_names: List[str], output_name: str): 75 | super().__init__() 76 | self.path = path 77 | self.input_names = input_names 78 | self.output_name = output_name 79 | self.device = torch.device("cpu") 80 | self.session = None 81 | self.recreate_session() 82 | 83 | @property 84 | def execution_provider(self) -> str: 85 | device = "CUDA" if self.device.type == "cuda" else "CPU" 86 | return f"{device}ExecutionProvider" 87 | 88 | def recreate_session(self): 89 | options = ort.SessionOptions() 90 | options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL 91 | self.session = ort.InferenceSession( 92 | self.path, 93 | sess_options=options, 94 | providers=[self.execution_provider], 95 | ) 96 | 97 | def to(self, device: 
torch.device) -> ONNXModel: 98 | if device.type != self.device.type: 99 | self.device = device 100 | self.recreate_session() 101 | return self 102 | 103 | def __call__(self, *args) -> torch.Tensor: 104 | inputs = { 105 | name: arg.cpu().numpy().astype(np.float32) 106 | for name, arg in zip(self.input_names, args) 107 | } 108 | output = self.session.run([self.output_name], inputs)[0] 109 | return torch.from_numpy(output).float().to(args[0].device) 110 | 111 | 112 | class LazyModel(ABC): 113 | def __init__(self, loader: Callable[[], Callable]): 114 | super().__init__() 115 | self.get_model = loader 116 | self.model: Optional[Callable] = None 117 | 118 | def is_in_memory(self) -> bool: 119 | """Return whether the model has been loaded into memory""" 120 | return self.model is not None 121 | 122 | def load(self): 123 | if not self.is_in_memory(): 124 | self.model = self.get_model() 125 | 126 | def to(self, device: torch.device) -> LazyModel: 127 | self.load() 128 | self.model = self.model.to(device) 129 | return self 130 | 131 | def __call__(self, *args, **kwargs): 132 | self.load() 133 | return self.model(*args, **kwargs) 134 | 135 | def eval(self) -> LazyModel: 136 | self.load() 137 | if isinstance(self.model, nn.Module): 138 | self.model.eval() 139 | return self 140 | 141 | 142 | class SegmentationModel(LazyModel): 143 | """ 144 | Minimal interface for a segmentation model. 145 | """ 146 | 147 | @staticmethod 148 | def from_pyannote( 149 | model, use_hf_token: Union[Text, bool, None] = True 150 | ) -> "SegmentationModel": 151 | """ 152 | Returns a `SegmentationModel` wrapping a pyannote model. 153 | 154 | Parameters 155 | ---------- 156 | model: pyannote.PipelineModel 157 | The pyannote.audio model to fetch. 158 | use_hf_token: str | bool, optional 159 | The Huggingface access token to use when downloading the model. 160 | If True, use huggingface-cli login token. 161 | Defaults to None. 162 | 163 | Returns 164 | ------- 165 | wrapper: SegmentationModel 166 | """ 167 | assert IS_PYANNOTE_AVAILABLE, "No pyannote.audio installation found" 168 | return SegmentationModel(PyannoteLoader(model, use_hf_token)) 169 | 170 | @staticmethod 171 | def from_onnx( 172 | model_path: Union[str, Path], 173 | input_name: str = "waveform", 174 | output_name: str = "segmentation", 175 | ) -> "SegmentationModel": 176 | assert IS_ONNX_AVAILABLE, "No ONNX installation found" 177 | return SegmentationModel(ONNXLoader(model_path, [input_name], output_name)) 178 | 179 | @staticmethod 180 | def from_pretrained( 181 | model, use_hf_token: Union[Text, bool, None] = True 182 | ) -> "SegmentationModel": 183 | if isinstance(model, str) or isinstance(model, Path): 184 | if Path(model).name.endswith(".onnx"): 185 | return SegmentationModel.from_onnx(model) 186 | return SegmentationModel.from_pyannote(model, use_hf_token) 187 | 188 | def __call__(self, waveform: torch.Tensor) -> torch.Tensor: 189 | """ 190 | Call the forward pass of the segmentation model. 191 | Parameters 192 | ---------- 193 | waveform: torch.Tensor, shape (batch, channels, samples) 194 | Returns 195 | ------- 196 | speaker_segmentation: torch.Tensor, shape (batch, frames, speakers) 197 | """ 198 | return super().__call__(waveform) 199 | 200 | 201 | class EmbeddingModel(LazyModel): 202 | """Minimal interface for an embedding model.""" 203 | 204 | @staticmethod 205 | def from_pyannote( 206 | model, use_hf_token: Union[Text, bool, None] = True 207 | ) -> "EmbeddingModel": 208 | """ 209 | Returns an `EmbeddingModel` wrapping a pyannote model. 
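        A minimal sketch (assumes pyannote.audio is installed and a valid
        Hugging Face token is available):

        >>> embedding_model = EmbeddingModel.from_pyannote("pyannote/embedding", use_hf_token=True)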
210 | 211 | Parameters 212 | ---------- 213 | model: pyannote.PipelineModel 214 | The pyannote.audio model to fetch. 215 | use_hf_token: str | bool, optional 216 | The Huggingface access token to use when downloading the model. 217 | If True, use huggingface-cli login token. 218 | Defaults to None. 219 | 220 | Returns 221 | ------- 222 | wrapper: EmbeddingModel 223 | """ 224 | assert IS_PYANNOTE_AVAILABLE, "No pyannote.audio installation found" 225 | loader = PyannoteLoader(model, use_hf_token) 226 | return EmbeddingModel(loader) 227 | 228 | @staticmethod 229 | def from_onnx( 230 | model_path: Union[str, Path], 231 | input_names: List[str] | None = None, 232 | output_name: str = "embedding", 233 | ) -> "EmbeddingModel": 234 | assert IS_ONNX_AVAILABLE, "No ONNX installation found" 235 | input_names = input_names or ["waveform", "weights"] 236 | loader = ONNXLoader(model_path, input_names, output_name) 237 | return EmbeddingModel(loader) 238 | 239 | @staticmethod 240 | def from_pretrained( 241 | model, use_hf_token: Union[Text, bool, None] = True 242 | ) -> "EmbeddingModel": 243 | if isinstance(model, str) or isinstance(model, Path): 244 | if Path(model).name.endswith(".onnx"): 245 | return EmbeddingModel.from_onnx(model) 246 | return EmbeddingModel.from_pyannote(model, use_hf_token) 247 | 248 | def __call__( 249 | self, waveform: torch.Tensor, weights: Optional[torch.Tensor] = None 250 | ) -> torch.Tensor: 251 | """ 252 | Call the forward pass of an embedding model with optional weights. 253 | Parameters 254 | ---------- 255 | waveform: torch.Tensor, shape (batch, channels, samples) 256 | weights: Optional[torch.Tensor], shape (batch, frames) 257 | Temporal weights for each sample in the batch. Defaults to no weights. 258 | Returns 259 | ------- 260 | speaker_embeddings: torch.Tensor, shape (batch, embedding_dim) 261 | """ 262 | embeddings = super().__call__(waveform, weights) 263 | if isinstance(embeddings, np.ndarray): 264 | embeddings = torch.from_numpy(embeddings) 265 | return embeddings 266 | -------------------------------------------------------------------------------- /src/diart/operators.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Callable, Optional, List, Any, Tuple 3 | 4 | import numpy as np 5 | import rx 6 | from pyannote.core import Annotation, SlidingWindow, SlidingWindowFeature, Segment 7 | from rx import operators as ops 8 | from rx.core import Observable 9 | 10 | Operator = Callable[[Observable], Observable] 11 | 12 | 13 | @dataclass 14 | class AudioBufferState: 15 | chunk: Optional[np.ndarray] 16 | buffer: Optional[np.ndarray] 17 | start_time: float 18 | changed: bool 19 | 20 | @staticmethod 21 | def initial(): 22 | return AudioBufferState(None, None, 0, False) 23 | 24 | @staticmethod 25 | def has_samples(num_samples: int): 26 | def call_fn(state) -> bool: 27 | return state.chunk is not None and state.chunk.shape[1] == num_samples 28 | 29 | return call_fn 30 | 31 | @staticmethod 32 | def to_sliding_window(sample_rate: int): 33 | def call_fn(state) -> SlidingWindowFeature: 34 | resolution = SlidingWindow( 35 | start=state.start_time, 36 | duration=1.0 / sample_rate, 37 | step=1.0 / sample_rate, 38 | ) 39 | return SlidingWindowFeature(state.chunk.T, resolution) 40 | 41 | return call_fn 42 | 43 | 44 | def rearrange_audio_stream( 45 | duration: float = 5, step: float = 0.5, sample_rate: int = 16000 46 | ) -> Operator: 47 | chunk_samples = int(round(sample_rate * duration)) 48 
| step_samples = int(round(sample_rate * step)) 49 | 50 | # FIXME this should flush buffer contents when the audio stops being emitted. 51 | # Right now this can be solved by using a block size that's a dividend of the step size. 52 | 53 | def accumulate(state: AudioBufferState, value: np.ndarray): 54 | # State contains the last emitted chunk, the current step buffer and the last starting time 55 | if value.ndim != 2 or value.shape[0] != 1: 56 | raise ValueError( 57 | f"Waveform must have shape (1, samples) but {value.shape} was found" 58 | ) 59 | start_time = state.start_time 60 | 61 | # Add new samples to the buffer 62 | buffer = ( 63 | value 64 | if state.buffer is None 65 | else np.concatenate([state.buffer, value], axis=1) 66 | ) 67 | 68 | # Check for buffer overflow 69 | if buffer.shape[1] >= step_samples: 70 | # Pop samples from buffer 71 | if buffer.shape[1] == step_samples: 72 | new_chunk, new_buffer = buffer, None 73 | else: 74 | new_chunk = buffer[:, :step_samples] 75 | new_buffer = buffer[:, step_samples:] 76 | 77 | # Add samples to next chunk 78 | if state.chunk is not None: 79 | new_chunk = np.concatenate([state.chunk, new_chunk], axis=1) 80 | 81 | # Truncate chunk to ensure a fixed duration 82 | if new_chunk.shape[1] > chunk_samples: 83 | new_chunk = new_chunk[:, -chunk_samples:] 84 | start_time += step 85 | 86 | # Chunk has changed because of buffer overflow 87 | return AudioBufferState(new_chunk, new_buffer, start_time, changed=True) 88 | 89 | # Chunk has not changed 90 | return AudioBufferState(state.chunk, buffer, start_time, changed=False) 91 | 92 | return rx.pipe( 93 | # Accumulate last <=duration seconds of waveform as an AudioBufferState 94 | ops.scan(accumulate, AudioBufferState.initial()), 95 | # Take only states that have the desired duration and whose chunk has changed since last time 96 | ops.filter(AudioBufferState.has_samples(chunk_samples)), 97 | ops.filter(lambda state: state.changed), 98 | # Transform state into a SlidingWindowFeature containing the new chunk 99 | ops.map(AudioBufferState.to_sliding_window(sample_rate)), 100 | ) 101 | 102 | 103 | def buffer_slide(n: int): 104 | def accumulate(state: List[Any], value: Any) -> List[Any]: 105 | new_state = [*state, value] 106 | if len(new_state) > n: 107 | return new_state[1:] 108 | return new_state 109 | 110 | return rx.pipe(ops.scan(accumulate, [])) 111 | 112 | 113 | @dataclass 114 | class PredictionWithAudio: 115 | prediction: Annotation 116 | waveform: Optional[SlidingWindowFeature] = None 117 | 118 | @property 119 | def has_audio(self) -> bool: 120 | return self.waveform is not None 121 | 122 | 123 | @dataclass 124 | class OutputAccumulationState: 125 | annotation: Optional[Annotation] 126 | waveform: Optional[SlidingWindowFeature] 127 | real_time: float 128 | next_sample: Optional[int] 129 | 130 | @staticmethod 131 | def initial() -> "OutputAccumulationState": 132 | return OutputAccumulationState(None, None, 0, 0) 133 | 134 | @property 135 | def cropped_waveform(self) -> SlidingWindowFeature: 136 | return SlidingWindowFeature( 137 | self.waveform[: self.next_sample], 138 | self.waveform.sliding_window, 139 | ) 140 | 141 | def to_tuple( 142 | self, 143 | ) -> Tuple[Optional[Annotation], Optional[SlidingWindowFeature], float]: 144 | return self.annotation, self.cropped_waveform, self.real_time 145 | 146 | 147 | def accumulate_output( 148 | duration: float, 149 | step: float, 150 | patch_collar: float = 0.05, 151 | ) -> Operator: 152 | """Accumulate predictions and audio to infinity: O(N) space complexity. 
153 | Uses a pre-allocated buffer that doubles its size once full: O(logN) concat operations. 154 | 155 | Parameters 156 | ---------- 157 | duration: float 158 | Buffer duration in seconds. 159 | step: float 160 | Duration of the chunks at each event in seconds. 161 | The first chunk may be bigger given the latency. 162 | patch_collar: float, optional 163 | Collar to merge speaker turns of the same speaker, in seconds. 164 | Defaults to 0.05 (i.e. 50ms). 165 | Returns 166 | ------- 167 | A reactive x operator implementing this behavior. 168 | """ 169 | 170 | def accumulate( 171 | state: OutputAccumulationState, 172 | value: Tuple[Annotation, Optional[SlidingWindowFeature]], 173 | ) -> OutputAccumulationState: 174 | value = PredictionWithAudio(*value) 175 | annotation, waveform = None, None 176 | 177 | # Determine the real time of the stream 178 | real_time = duration if state.annotation is None else state.real_time + step 179 | 180 | # Update total annotation with current predictions 181 | if state.annotation is None: 182 | annotation = value.prediction 183 | else: 184 | annotation = state.annotation.update(value.prediction).support(patch_collar) 185 | 186 | # Update total waveform if there's audio in the input 187 | new_next_sample = 0 188 | if value.has_audio: 189 | num_new_samples = value.waveform.data.shape[0] 190 | new_next_sample = state.next_sample + num_new_samples 191 | sw_holder = state 192 | if state.waveform is None: 193 | # Initialize the audio buffer with 10 times the size of the first chunk 194 | waveform, sw_holder = np.zeros((10 * num_new_samples, 1)), value 195 | elif new_next_sample < state.waveform.data.shape[0]: 196 | # The buffer still has enough space to accommodate the chunk 197 | waveform = state.waveform.data 198 | else: 199 | # The buffer is full, double its size 200 | waveform = np.concatenate( 201 | (state.waveform.data, np.zeros_like(state.waveform.data)), axis=0 202 | ) 203 | # Copy chunk into buffer 204 | waveform[state.next_sample : new_next_sample] = value.waveform.data 205 | waveform = SlidingWindowFeature(waveform, sw_holder.waveform.sliding_window) 206 | 207 | return OutputAccumulationState(annotation, waveform, real_time, new_next_sample) 208 | 209 | return rx.pipe( 210 | ops.scan(accumulate, OutputAccumulationState.initial()), 211 | ops.map(OutputAccumulationState.to_tuple), 212 | ) 213 | 214 | 215 | def buffer_output( 216 | duration: float, 217 | step: float, 218 | latency: float, 219 | sample_rate: int, 220 | patch_collar: float = 0.05, 221 | ) -> Operator: 222 | """Store last predictions and audio inside a fixed buffer. 223 | Provides the best time/space complexity trade-off if the past data is not needed. 224 | 225 | Parameters 226 | ---------- 227 | duration: float 228 | Buffer duration in seconds. 229 | step: float 230 | Duration of the chunks at each event in seconds. 231 | The first chunk may be bigger given the latency. 232 | latency: float 233 | Latency of the system in seconds. 234 | sample_rate: int 235 | Sample rate of the audio source. 236 | patch_collar: float, optional 237 | Collar to merge speaker turns of the same speaker, in seconds. 238 | Defaults to 0.05 (i.e. 50ms). 239 | 240 | Returns 241 | ------- 242 | A reactive x operator implementing this behavior. 
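
    Examples
    --------
    A minimal sketch of plugging the operator into a reactivex pipeline
    (the `predictions` observable of (Annotation, SlidingWindowFeature)
    pairs is assumed here for illustration):

    >>> buffered = predictions.pipe(
    ...     buffer_output(duration=5, step=0.5, latency=0.5, sample_rate=16000),
    ... )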
243 | """ 244 | # Define some useful constants 245 | num_samples = int(round(duration * sample_rate)) 246 | num_step_samples = int(round(step * sample_rate)) 247 | resolution = 1 / sample_rate 248 | 249 | def accumulate( 250 | state: OutputAccumulationState, 251 | value: Tuple[Annotation, Optional[SlidingWindowFeature]], 252 | ) -> OutputAccumulationState: 253 | value = PredictionWithAudio(*value) 254 | annotation, waveform = None, None 255 | 256 | # Determine the real time of the stream and the start time of the buffer 257 | real_time = duration if state.annotation is None else state.real_time + step 258 | start_time = max(0.0, real_time - latency - duration) 259 | 260 | # Update annotation and constrain its bounds to the buffer 261 | if state.annotation is None: 262 | annotation = value.prediction 263 | else: 264 | annotation = state.annotation.update(value.prediction).support(patch_collar) 265 | if start_time > 0: 266 | annotation = annotation.extrude(Segment(0, start_time)) 267 | 268 | # Update the audio buffer if there's audio in the input 269 | new_next_sample = state.next_sample + num_step_samples 270 | if value.has_audio: 271 | if state.waveform is None: 272 | # Determine the size of the first chunk 273 | expected_duration = duration + step - latency 274 | expected_samples = int(round(expected_duration * sample_rate)) 275 | # Shift indicator to start copying new audio in the buffer 276 | new_next_sample = state.next_sample + expected_samples 277 | # Buffer size is duration + step 278 | waveform = np.zeros((num_samples + num_step_samples, 1)) 279 | # Copy first chunk into buffer (slicing because of rounding errors) 280 | waveform[:expected_samples] = value.waveform.data[:expected_samples] 281 | elif state.next_sample <= num_samples: 282 | # The buffer isn't full, copy into next free buffer chunk 283 | waveform = state.waveform.data 284 | waveform[state.next_sample : new_next_sample] = value.waveform.data 285 | else: 286 | # The buffer is full, shift values to the left and copy into last buffer chunk 287 | waveform = np.roll(state.waveform.data, -num_step_samples, axis=0) 288 | # If running on a file, the online prediction may be shorter depending on the latency 289 | # The remaining audio at the end is appended, so value.waveform may be longer than num_step_samples 290 | # In that case, we simply ignore the appended samples. 291 | waveform[-num_step_samples:] = value.waveform.data[:num_step_samples] 292 | 293 | # Wrap waveform in a sliding window feature to include timestamps 294 | window = SlidingWindow( 295 | start=start_time, duration=resolution, step=resolution 296 | ) 297 | waveform = SlidingWindowFeature(waveform, window) 298 | 299 | return OutputAccumulationState(annotation, waveform, real_time, new_next_sample) 300 | 301 | return rx.pipe( 302 | ops.scan(accumulate, OutputAccumulationState.initial()), 303 | ops.map(OutputAccumulationState.to_tuple), 304 | ) 305 | -------------------------------------------------------------------------------- /src/diart/optim.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from pathlib import Path 3 | from typing import Sequence, Text, Optional, Union 4 | 5 | from optuna import TrialPruned, Study, create_study 6 | from optuna.samplers import TPESampler 7 | from optuna.trial import Trial, FrozenTrial 8 | from pyannote.metrics.base import BaseMetric 9 | from tqdm import trange, tqdm 10 | from typing_extensions import Literal 11 | 12 | from . 
import blocks 13 | from .audio import FilePath 14 | from .inference import Benchmark 15 | 16 | 17 | class Optimizer: 18 | def __init__( 19 | self, 20 | pipeline_class: type, 21 | speech_path: Union[Text, Path], 22 | reference_path: Union[Text, Path], 23 | study_or_path: Union[FilePath, Study], 24 | batch_size: int = 32, 25 | hparams: Optional[Sequence[blocks.base.HyperParameter]] = None, 26 | base_config: Optional[blocks.PipelineConfig] = None, 27 | do_kickstart_hparams: bool = True, 28 | metric: Optional[BaseMetric] = None, 29 | direction: Literal["minimize", "maximize"] = "minimize", 30 | ): 31 | self.pipeline_class = pipeline_class 32 | # FIXME can we run this benchmark in parallel? 33 | # Currently it breaks the trial progress bar 34 | self.benchmark = Benchmark( 35 | speech_path, 36 | reference_path, 37 | show_progress=True, 38 | show_report=False, 39 | batch_size=batch_size, 40 | ) 41 | 42 | self.metric = metric 43 | self.direction = direction 44 | self.base_config = base_config 45 | self.do_kickstart_hparams = do_kickstart_hparams 46 | if self.base_config is None: 47 | self.base_config = self.pipeline_class.get_config_class()() 48 | self.do_kickstart_hparams = False 49 | 50 | self.hparams = hparams 51 | if self.hparams is None: 52 | self.hparams = self.pipeline_class.hyper_parameters() 53 | 54 | # Make sure hyper-parameters exist in the configuration class given 55 | possible_hparams = vars(self.base_config) 56 | for param in self.hparams: 57 | msg = ( 58 | f"Hyper-parameter {param.name} not found " 59 | f"in configuration {self.base_config.__class__.__name__}" 60 | ) 61 | assert param.name in possible_hparams, msg 62 | 63 | self._progress: Optional[tqdm] = None 64 | 65 | if isinstance(study_or_path, Study): 66 | self.study = study_or_path 67 | elif isinstance(study_or_path, str) or isinstance(study_or_path, Path): 68 | study_or_path = Path(study_or_path) 69 | self.study = create_study( 70 | storage="sqlite:///" + str(study_or_path / f"{study_or_path.stem}.db"), 71 | sampler=TPESampler(), 72 | study_name=study_or_path.stem, 73 | direction=self.direction, 74 | load_if_exists=True, 75 | ) 76 | else: 77 | msg = f"Expected Study object or path-like, but got {type(study_or_path).__name__}" 78 | raise ValueError(msg) 79 | 80 | @property 81 | def best_performance(self): 82 | return self.study.best_value 83 | 84 | @property 85 | def best_hparams(self): 86 | return self.study.best_params 87 | 88 | def _callback(self, study: Study, trial: FrozenTrial): 89 | if self._progress is None: 90 | return 91 | self._progress.update(1) 92 | self._progress.set_description(f"Trial {trial.number + 1}") 93 | values = {"best_perf": study.best_value} 94 | for name, value in study.best_params.items(): 95 | values[f"best_{name}"] = value 96 | self._progress.set_postfix(OrderedDict(values)) 97 | 98 | def objective(self, trial: Trial) -> float: 99 | # Set suggested values for optimized hyper-parameters 100 | trial_config = vars(self.base_config) 101 | for hparam in self.hparams: 102 | trial_config[hparam.name] = trial.suggest_uniform( 103 | hparam.name, hparam.low, hparam.high 104 | ) 105 | 106 | # Prune trial if required 107 | if trial.should_prune(): 108 | raise TrialPruned() 109 | 110 | # Instantiate the new configuration for the trial 111 | config = self.base_config.__class__(**trial_config) 112 | 113 | # Determine the evaluation metric 114 | metric = self.metric 115 | if metric is None: 116 | metric = self.pipeline_class.suggest_metric() 117 | 118 | # Run pipeline over the dataset 119 | report = 
self.benchmark(self.pipeline_class, config, metric) 120 | 121 | # Extract target metric from report 122 | return report.loc["TOTAL", metric.name]["%"] 123 | 124 | def __call__(self, num_iter: int, show_progress: bool = True): 125 | self._progress = None 126 | if show_progress: 127 | self._progress = trange(num_iter) 128 | last_trial = -1 129 | if self.study.trials: 130 | last_trial = self.study.trials[-1].number 131 | self._progress.set_description(f"Trial {last_trial + 1}") 132 | # Start with base config hyper-parameters if config was given 133 | if self.do_kickstart_hparams: 134 | self.study.enqueue_trial( 135 | { 136 | param.name: getattr(self.base_config, param.name) 137 | for param in self.hparams 138 | }, 139 | skip_if_exists=True, 140 | ) 141 | self.study.optimize(self.objective, num_iter, callbacks=[self._callback]) 142 | -------------------------------------------------------------------------------- /src/diart/progress.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Optional, Text 3 | 4 | import rich 5 | from rich.progress import Progress, TaskID 6 | from tqdm import tqdm 7 | 8 | 9 | class ProgressBar(ABC): 10 | @abstractmethod 11 | def create( 12 | self, 13 | total: int, 14 | description: Optional[Text] = None, 15 | unit: Text = "it", 16 | **kwargs, 17 | ): 18 | pass 19 | 20 | @abstractmethod 21 | def start(self): 22 | pass 23 | 24 | @abstractmethod 25 | def update(self, n: int = 1): 26 | pass 27 | 28 | @abstractmethod 29 | def write(self, text: Text): 30 | pass 31 | 32 | @abstractmethod 33 | def stop(self): 34 | pass 35 | 36 | @abstractmethod 37 | def close(self): 38 | pass 39 | 40 | @property 41 | @abstractmethod 42 | def default_description(self) -> Text: 43 | pass 44 | 45 | @property 46 | @abstractmethod 47 | def initial_description(self) -> Optional[Text]: 48 | pass 49 | 50 | def resolve_description(self, new_description: Optional[Text] = None) -> Text: 51 | if self.initial_description is None: 52 | if new_description is None: 53 | return self.default_description 54 | return new_description 55 | else: 56 | return self.initial_description 57 | 58 | 59 | class RichProgressBar(ProgressBar): 60 | def __init__( 61 | self, 62 | description: Optional[Text] = None, 63 | color: Text = "green", 64 | leave: bool = True, 65 | do_close: bool = True, 66 | ): 67 | self.description = description 68 | self.color = color 69 | self.do_close = do_close 70 | self.bar = Progress(transient=not leave) 71 | self.bar.start() 72 | self.task_id: Optional[TaskID] = None 73 | 74 | @property 75 | def default_description(self) -> Text: 76 | return f"[{self.color}]Streaming" 77 | 78 | @property 79 | def initial_description(self) -> Optional[Text]: 80 | if self.description is not None: 81 | return f"[{self.color}]{self.description}" 82 | return self.description 83 | 84 | def create( 85 | self, 86 | total: int, 87 | description: Optional[Text] = None, 88 | unit: Text = "it", 89 | **kwargs, 90 | ): 91 | if self.task_id is None: 92 | self.task_id = self.bar.add_task( 93 | self.resolve_description(f"[{self.color}]{description}"), 94 | start=False, 95 | total=total, 96 | completed=0, 97 | visible=True, 98 | **kwargs, 99 | ) 100 | 101 | def start(self): 102 | assert self.task_id is not None 103 | self.bar.start_task(self.task_id) 104 | 105 | def update(self, n: int = 1): 106 | assert self.task_id is not None 107 | self.bar.update(self.task_id, advance=n) 108 | 109 | def write(self, text: Text): 110 | 
rich.print(text) 111 | 112 | def stop(self): 113 | assert self.task_id is not None 114 | self.bar.stop_task(self.task_id) 115 | 116 | def close(self): 117 | if self.do_close: 118 | self.bar.stop() 119 | 120 | 121 | class TQDMProgressBar(ProgressBar): 122 | def __init__( 123 | self, 124 | description: Optional[Text] = None, 125 | leave: bool = True, 126 | position: Optional[int] = None, 127 | do_close: bool = True, 128 | ): 129 | self.description = description 130 | self.leave = leave 131 | self.position = position 132 | self.do_close = do_close 133 | self.pbar: Optional[tqdm] = None 134 | 135 | @property 136 | def default_description(self) -> Text: 137 | return "Streaming" 138 | 139 | @property 140 | def initial_description(self) -> Optional[Text]: 141 | return self.description 142 | 143 | def create( 144 | self, 145 | total: int, 146 | description: Optional[Text] = None, 147 | unit: Optional[Text] = "it", 148 | **kwargs, 149 | ): 150 | if self.pbar is None: 151 | self.pbar = tqdm( 152 | desc=self.resolve_description(description), 153 | total=total, 154 | unit=unit, 155 | leave=self.leave, 156 | position=self.position, 157 | **kwargs, 158 | ) 159 | 160 | def start(self): 161 | pass 162 | 163 | def update(self, n: int = 1): 164 | assert self.pbar is not None 165 | self.pbar.update(n) 166 | 167 | def write(self, text: Text): 168 | tqdm.write(text) 169 | 170 | def stop(self): 171 | self.close() 172 | 173 | def close(self): 174 | if self.do_close: 175 | assert self.pbar is not None 176 | self.pbar.close() 177 | -------------------------------------------------------------------------------- /src/diart/sinks.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Union, Text, Optional, Tuple 3 | 4 | import matplotlib.pyplot as plt 5 | from pyannote.core import Annotation, Segment, SlidingWindowFeature, notebook 6 | from pyannote.database.util import load_rttm 7 | from pyannote.metrics.diarization import DiarizationErrorRate 8 | from rx.core import Observer 9 | from typing_extensions import Literal 10 | 11 | 12 | class WindowClosedException(Exception): 13 | pass 14 | 15 | 16 | def _extract_prediction(value: Union[Tuple, Annotation]) -> Annotation: 17 | if isinstance(value, tuple): 18 | return value[0] 19 | if isinstance(value, Annotation): 20 | return value 21 | msg = f"Expected tuple or Annotation, but got {type(value)}" 22 | raise ValueError(msg) 23 | 24 | 25 | class RTTMWriter(Observer): 26 | def __init__(self, uri: Text, path: Union[Path, Text], patch_collar: float = 0.05): 27 | super().__init__() 28 | self.uri = uri 29 | self.patch_collar = patch_collar 30 | self.path = Path(path).expanduser() 31 | if self.path.exists(): 32 | self.path.unlink() 33 | 34 | def patch(self): 35 | """Stitch same-speaker turns that are close to each other""" 36 | if not self.path.exists(): 37 | return 38 | annotations = list(load_rttm(self.path).values()) 39 | if annotations: 40 | annotation = annotations[0] 41 | annotation.uri = self.uri 42 | with open(self.path, "w") as file: 43 | annotation.support(self.patch_collar).write_rttm(file) 44 | 45 | def on_next(self, value: Union[Tuple, Annotation]): 46 | prediction = _extract_prediction(value) 47 | # Write prediction in RTTM format 48 | prediction.uri = self.uri 49 | with open(self.path, "a") as file: 50 | prediction.write_rttm(file) 51 | 52 | def on_error(self, error: Exception): 53 | self.patch() 54 | 55 | def on_completed(self): 56 | self.patch() 57 | 58 | 59 | class 
PredictionAccumulator(Observer): 60 | def __init__(self, uri: Optional[Text] = None, patch_collar: float = 0.05): 61 | super().__init__() 62 | self.uri = uri 63 | self.patch_collar = patch_collar 64 | self._prediction: Optional[Annotation] = None 65 | 66 | def patch(self): 67 | """Stitch same-speaker turns that are close to each other""" 68 | if self._prediction is not None: 69 | self._prediction = self._prediction.support(self.patch_collar) 70 | 71 | def get_prediction(self) -> Annotation: 72 | # Patch again in case this is called before on_completed 73 | self.patch() 74 | return self._prediction 75 | 76 | def on_next(self, value: Union[Tuple, Annotation]): 77 | prediction = _extract_prediction(value) 78 | prediction.uri = self.uri 79 | if self._prediction is None: 80 | self._prediction = prediction 81 | else: 82 | self._prediction.update(prediction) 83 | 84 | def on_error(self, error: Exception): 85 | self.patch() 86 | 87 | def on_completed(self): 88 | self.patch() 89 | 90 | 91 | class StreamingPlot(Observer): 92 | def __init__( 93 | self, 94 | duration: float, 95 | latency: float, 96 | visualization: Literal["slide", "accumulate"] = "slide", 97 | reference: Optional[Union[Path, Text]] = None, 98 | ): 99 | super().__init__() 100 | assert visualization in ["slide", "accumulate"] 101 | self.visualization = visualization 102 | self.reference = reference 103 | if self.reference is not None: 104 | self.reference = list(load_rttm(reference).values())[0] 105 | self.window_duration = duration 106 | self.latency = latency 107 | self.figure, self.axs, self.num_axs = None, None, -1 108 | # This flag allows to catch the matplotlib window closed event and make the next call stop iterating 109 | self.window_closed = False 110 | 111 | def _on_window_closed(self, event): 112 | self.window_closed = True 113 | 114 | def _init_num_axs(self): 115 | if self.num_axs == -1: 116 | self.num_axs = 2 117 | if self.reference is not None: 118 | self.num_axs += 1 119 | 120 | def _init_figure(self): 121 | self._init_num_axs() 122 | self.figure, self.axs = plt.subplots( 123 | self.num_axs, 1, figsize=(10, 2 * self.num_axs) 124 | ) 125 | if self.num_axs == 1: 126 | self.axs = [self.axs] 127 | self.figure.canvas.mpl_connect("close_event", self._on_window_closed) 128 | 129 | def _clear_axs(self): 130 | for i in range(self.num_axs): 131 | self.axs[i].clear() 132 | 133 | def get_plot_bounds(self, real_time: float) -> Segment: 134 | start_time = 0 135 | end_time = real_time - self.latency 136 | if self.visualization == "slide": 137 | start_time = max(0.0, end_time - self.window_duration) 138 | return Segment(start_time, end_time) 139 | 140 | def on_next(self, values: Tuple[Annotation, SlidingWindowFeature, float]): 141 | if self.window_closed: 142 | raise WindowClosedException 143 | 144 | prediction, waveform, real_time = values 145 | 146 | # Initialize figure if first call 147 | if self.figure is None: 148 | self._init_figure() 149 | # Clear previous plots 150 | self._clear_axs() 151 | # Set plot bounds 152 | notebook.crop = self.get_plot_bounds(real_time) 153 | 154 | # Align prediction and reference if possible 155 | if self.reference is not None: 156 | metric = DiarizationErrorRate() 157 | mapping = metric.optimal_mapping(self.reference, prediction) 158 | prediction.rename_labels(mapping=mapping, copy=False) 159 | 160 | # Plot prediction 161 | notebook.plot_annotation(prediction, self.axs[0]) 162 | self.axs[0].set_title("Output") 163 | 164 | # Plot waveform 165 | notebook.plot_feature(waveform, self.axs[1]) 166 | 
self.axs[1].set_title("Audio") 167 | 168 | # Plot reference if available 169 | if self.num_axs == 3: 170 | notebook.plot_annotation(self.reference, self.axs[2]) 171 | self.axs[2].set_title("Reference") 172 | 173 | # Draw 174 | plt.tight_layout() 175 | self.figure.canvas.draw() 176 | self.figure.canvas.flush_events() 177 | plt.pause(0.05) 178 | -------------------------------------------------------------------------------- /src/diart/sources.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from pathlib import Path 3 | from queue import SimpleQueue 4 | from typing import Text, Optional, AnyStr, Dict, Any, Union, Tuple 5 | 6 | import numpy as np 7 | import sounddevice as sd 8 | import torch 9 | from einops import rearrange 10 | from rx.subject import Subject 11 | from torchaudio.io import StreamReader 12 | from websocket_server import WebsocketServer 13 | 14 | from . import utils 15 | from .audio import FilePath, AudioLoader 16 | 17 | 18 | class AudioSource(ABC): 19 | """Represents a source of audio that can start streaming via the `stream` property. 20 | 21 | Parameters 22 | ---------- 23 | uri: Text 24 | Unique identifier of the audio source. 25 | sample_rate: int 26 | Sample rate of the audio source. 27 | """ 28 | 29 | def __init__(self, uri: Text, sample_rate: int): 30 | self.uri = uri 31 | self.sample_rate = sample_rate 32 | self.stream = Subject() 33 | 34 | @property 35 | def duration(self) -> Optional[float]: 36 | """The duration of the stream if known. Defaults to None (unknown duration).""" 37 | return None 38 | 39 | @abstractmethod 40 | def read(self): 41 | """Start reading the source and yielding samples through the stream.""" 42 | pass 43 | 44 | @abstractmethod 45 | def close(self): 46 | """Stop reading the source and close all open streams.""" 47 | pass 48 | 49 | 50 | class FileAudioSource(AudioSource): 51 | """Represents an audio source tied to a file. 52 | 53 | Parameters 54 | ---------- 55 | file: FilePath 56 | Path to the file to stream. 57 | sample_rate: int 58 | Sample rate of the chunks emitted. 59 | padding: (float, float) 60 | Left and right padding to add to the file (in seconds). 61 | Defaults to (0, 0). 62 | block_duration: float 63 | Duration of each emitted chunk in seconds. 64 | Defaults to 0.5 seconds. 
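    Examples
    --------
    A minimal usage sketch (the file name, the 16 kHz sample rate and the printed
    chunk shape below are illustrative assumptions, not defaults of this class):

    >>> source = FileAudioSource("sample.wav", sample_rate=16000, block_duration=0.5)
    >>> source.stream.subscribe(lambda chunk: print(chunk.shape))  # (1, 8000) per chunk
    >>> source.read()  # emits all chunks in order, then completes the stream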
65 | """ 66 | 67 | def __init__( 68 | self, 69 | file: FilePath, 70 | sample_rate: int, 71 | padding: Tuple[float, float] = (0, 0), 72 | block_duration: float = 0.5, 73 | ): 74 | super().__init__(Path(file).stem, sample_rate) 75 | self.loader = AudioLoader(self.sample_rate, mono=True) 76 | self._duration = self.loader.get_duration(file) 77 | self.file = file 78 | self.resolution = 1 / self.sample_rate 79 | self.block_size = int(np.rint(block_duration * self.sample_rate)) 80 | self.padding_start, self.padding_end = padding 81 | self.is_closed = False 82 | 83 | @property 84 | def duration(self) -> Optional[float]: 85 | # The duration of a file is known 86 | return self.padding_start + self._duration + self.padding_end 87 | 88 | def read(self): 89 | """Send each chunk of samples through the stream""" 90 | waveform = self.loader.load(self.file) 91 | 92 | # Add zero padding at the beginning if required 93 | if self.padding_start > 0: 94 | num_pad_samples = int(np.rint(self.padding_start * self.sample_rate)) 95 | zero_padding = torch.zeros(waveform.shape[0], num_pad_samples) 96 | waveform = torch.cat([zero_padding, waveform], dim=1) 97 | 98 | # Add zero padding at the end if required 99 | if self.padding_end > 0: 100 | num_pad_samples = int(np.rint(self.padding_end * self.sample_rate)) 101 | zero_padding = torch.zeros(waveform.shape[0], num_pad_samples) 102 | waveform = torch.cat([waveform, zero_padding], dim=1) 103 | 104 | # Split into blocks 105 | _, num_samples = waveform.shape 106 | chunks = rearrange( 107 | waveform.unfold(1, self.block_size, self.block_size), 108 | "channel chunk sample -> chunk channel sample", 109 | ).numpy() 110 | 111 | # Add last incomplete chunk with padding 112 | if num_samples % self.block_size != 0: 113 | last_chunk = ( 114 | waveform[:, chunks.shape[0] * self.block_size :].unsqueeze(0).numpy() 115 | ) 116 | diff_samples = self.block_size - last_chunk.shape[-1] 117 | last_chunk = np.concatenate( 118 | [last_chunk, np.zeros((1, 1, diff_samples))], axis=-1 119 | ) 120 | chunks = np.vstack([chunks, last_chunk]) 121 | 122 | # Stream blocks 123 | for i, waveform in enumerate(chunks): 124 | try: 125 | if self.is_closed: 126 | break 127 | self.stream.on_next(waveform) 128 | except BaseException as e: 129 | self.stream.on_error(e) 130 | break 131 | self.stream.on_completed() 132 | self.close() 133 | 134 | def close(self): 135 | self.is_closed = True 136 | 137 | 138 | class MicrophoneAudioSource(AudioSource): 139 | """Audio source tied to a local microphone. 140 | 141 | Parameters 142 | ---------- 143 | block_duration: float 144 | Duration of each emitted chunk in seconds. 145 | Defaults to 0.5 seconds. 146 | device: int | str | (int, str) | None 147 | Device identifier compatible with the sounddevice stream. 148 | If None, use the default device. 149 | Defaults to None. 
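    Examples
    --------
    A minimal usage sketch (the subscriber below just prints chunk shapes and is
    purely illustrative; the actual chunk size depends on the negotiated sample rate):

    >>> mic = MicrophoneAudioSource(block_duration=0.5)
    >>> mic.stream.subscribe(lambda chunk: print(chunk.shape))
    >>> mic.read()  # blocks, emitting (1, block_size) chunks until close() is called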
150 | """ 151 | 152 | def __init__( 153 | self, 154 | block_duration: float = 0.5, 155 | device: Optional[Union[int, Text, Tuple[int, Text]]] = None, 156 | ): 157 | # Use the lowest supported sample rate 158 | sample_rates = [16000, 32000, 44100, 48000] 159 | best_sample_rate = None 160 | for sr in sample_rates: 161 | try: 162 | sd.check_input_settings(device=device, samplerate=sr) 163 | except Exception: 164 | pass 165 | else: 166 | best_sample_rate = sr 167 | break 168 | super().__init__(f"input_device:{device}", best_sample_rate) 169 | 170 | # Determine block size in samples and create input stream 171 | self.block_size = int(np.rint(block_duration * self.sample_rate)) 172 | self._mic_stream = sd.InputStream( 173 | channels=1, 174 | samplerate=self.sample_rate, 175 | latency=0, 176 | blocksize=self.block_size, 177 | callback=self._read_callback, 178 | device=device, 179 | ) 180 | self._queue = SimpleQueue() 181 | 182 | def _read_callback(self, samples, *args): 183 | self._queue.put_nowait(samples[:, [0]].T) 184 | 185 | def read(self): 186 | self._mic_stream.start() 187 | while self._mic_stream: 188 | try: 189 | while self._queue.empty(): 190 | if self._mic_stream.closed: 191 | break 192 | self.stream.on_next(self._queue.get_nowait()) 193 | except BaseException as e: 194 | self.stream.on_error(e) 195 | break 196 | self.stream.on_completed() 197 | self.close() 198 | 199 | def close(self): 200 | self._mic_stream.stop() 201 | self._mic_stream.close() 202 | 203 | 204 | class WebSocketAudioSource(AudioSource): 205 | """Represents a source of audio coming from the network using the WebSocket protocol. 206 | 207 | Parameters 208 | ---------- 209 | sample_rate: int 210 | Sample rate of the chunks emitted. 211 | host: Text 212 | The host to run the websocket server. 213 | Defaults to 127.0.0.1. 214 | port: int 215 | The port to run the websocket server. 216 | Defaults to 7007. 217 | key: Text | Path | None 218 | Path to a key if using SSL. 219 | Defaults to no key. 220 | certificate: Text | Path | None 221 | Path to a certificate if using SSL. 222 | Defaults to no certificate. 223 | """ 224 | 225 | def __init__( 226 | self, 227 | sample_rate: int, 228 | host: Text = "127.0.0.1", 229 | port: int = 7007, 230 | key: Optional[Union[Text, Path]] = None, 231 | certificate: Optional[Union[Text, Path]] = None, 232 | ): 233 | # FIXME sample_rate is not being used, this can be confusing and lead to incompatibilities. 
234 | # I would prefer the client to send a JSON with data and sample rate, then resample if needed 235 | super().__init__(f"{host}:{port}", sample_rate) 236 | self.client: Optional[Dict[Text, Any]] = None 237 | self.server = WebsocketServer(host, port, key=key, cert=certificate) 238 | self.server.set_fn_message_received(self._on_message_received) 239 | 240 | def _on_message_received( 241 | self, 242 | client: Dict[Text, Any], 243 | server: WebsocketServer, 244 | message: AnyStr, 245 | ): 246 | # Only one client at a time is allowed 247 | if self.client is None or self.client["id"] != client["id"]: 248 | self.client = client 249 | # Send decoded audio to pipeline 250 | self.stream.on_next(utils.decode_audio(message)) 251 | 252 | def read(self): 253 | """Starts running the websocket server and listening for audio chunks""" 254 | self.server.run_forever() 255 | 256 | def close(self): 257 | """Close the websocket server""" 258 | if self.server is not None: 259 | self.stream.on_completed() 260 | self.server.shutdown_gracefully() 261 | 262 | def send(self, message: AnyStr): 263 | """Send a message through the current websocket. 264 | 265 | Parameters 266 | ---------- 267 | message: AnyStr 268 | Bytes or string to send. 269 | """ 270 | if len(message) > 0: 271 | self.server.send_message(self.client, message) 272 | 273 | 274 | class TorchStreamAudioSource(AudioSource): 275 | def __init__( 276 | self, 277 | uri: Text, 278 | sample_rate: int, 279 | streamer: StreamReader, 280 | stream_index: Optional[int] = None, 281 | block_duration: float = 0.5, 282 | ): 283 | super().__init__(uri, sample_rate) 284 | self.block_size = int(np.rint(block_duration * self.sample_rate)) 285 | self._streamer = streamer 286 | self._streamer.add_basic_audio_stream( 287 | frames_per_chunk=self.block_size, 288 | stream_index=stream_index, 289 | format="fltp", 290 | sample_rate=self.sample_rate, 291 | ) 292 | self.is_closed = False 293 | 294 | def read(self): 295 | for item in self._streamer.stream(): 296 | try: 297 | if self.is_closed: 298 | break 299 | # shape (samples, channels) to (1, samples) 300 | chunk = np.mean(item[0].numpy(), axis=1, keepdims=True).T 301 | self.stream.on_next(chunk) 302 | except BaseException as e: 303 | self.stream.on_error(e) 304 | break 305 | self.stream.on_completed() 306 | self.close() 307 | 308 | def close(self): 309 | self.is_closed = True 310 | 311 | 312 | class AppleDeviceAudioSource(TorchStreamAudioSource): 313 | def __init__( 314 | self, 315 | sample_rate: int, 316 | device: str = "0:0", 317 | stream_index: int = 0, 318 | block_duration: float = 0.5, 319 | ): 320 | uri = f"apple_input_device:{device}" 321 | streamer = StreamReader(device, format="avfoundation") 322 | super().__init__(uri, sample_rate, streamer, stream_index, block_duration) 323 | -------------------------------------------------------------------------------- /src/diart/utils.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import time 3 | from typing import Optional, Text, Union 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | from pyannote.core import Annotation, Segment, SlidingWindowFeature, notebook 8 | 9 | from . 
import blocks 10 | from .progress import ProgressBar 11 | 12 | 13 | class Chronometer: 14 | def __init__(self, unit: Text, progress_bar: Optional[ProgressBar] = None): 15 | self.unit = unit 16 | self.progress_bar = progress_bar 17 | self.current_start_time = None 18 | self.history = [] 19 | 20 | @property 21 | def is_running(self): 22 | return self.current_start_time is not None 23 | 24 | def start(self): 25 | self.current_start_time = time.monotonic() 26 | 27 | def stop(self, do_count: bool = True): 28 | msg = "No start time available, Did you call stop() before start()?" 29 | assert self.current_start_time is not None, msg 30 | end_time = time.monotonic() - self.current_start_time 31 | self.current_start_time = None 32 | if do_count: 33 | self.history.append(end_time) 34 | 35 | def report(self): 36 | print_fn = print 37 | if self.progress_bar is not None: 38 | print_fn = self.progress_bar.write 39 | print_fn( 40 | f"Took {np.mean(self.history).item():.3f} " 41 | f"(+/-{np.std(self.history).item():.3f}) seconds/{self.unit} " 42 | f"-- ran {len(self.history)} times" 43 | ) 44 | 45 | 46 | def parse_hf_token_arg(hf_token: Union[bool, Text]) -> Union[bool, Text]: 47 | if isinstance(hf_token, bool): 48 | return hf_token 49 | if hf_token.lower() == "true": 50 | return True 51 | if hf_token.lower() == "false": 52 | return False 53 | return hf_token 54 | 55 | 56 | def encode_audio(waveform: np.ndarray) -> Text: 57 | data = waveform.astype(np.float32).tobytes() 58 | return base64.b64encode(data).decode("utf-8") 59 | 60 | 61 | def decode_audio(data: Text) -> np.ndarray: 62 | # Decode chunk encoded in base64 63 | byte_samples = base64.decodebytes(data.encode("utf-8")) 64 | # Recover array from bytes 65 | samples = np.frombuffer(byte_samples, dtype=np.float32) 66 | return samples.reshape(1, -1) 67 | 68 | 69 | def get_padding_left(stream_duration: float, chunk_duration: float) -> float: 70 | if stream_duration < chunk_duration: 71 | return chunk_duration - stream_duration 72 | return 0 73 | 74 | 75 | def repeat_label(label: Text): 76 | while True: 77 | yield label 78 | 79 | 80 | def get_pipeline_class(class_name: Text) -> type: 81 | pipeline_class = getattr(blocks, class_name, None) 82 | msg = f"Pipeline '{class_name}' doesn't exist" 83 | assert pipeline_class is not None, msg 84 | return pipeline_class 85 | 86 | 87 | def get_padding_right(latency: float, step: float) -> float: 88 | return latency - step 89 | 90 | 91 | def visualize_feature(duration: Optional[float] = None): 92 | def apply(feature: SlidingWindowFeature): 93 | if duration is None: 94 | notebook.crop = feature.extent 95 | else: 96 | notebook.crop = Segment(feature.extent.end - duration, feature.extent.end) 97 | plt.rcParams["figure.figsize"] = (8, 2) 98 | notebook.plot_feature(feature) 99 | plt.tight_layout() 100 | plt.show() 101 | 102 | return apply 103 | 104 | 105 | def visualize_annotation(duration: Optional[float] = None): 106 | def apply(annotation: Annotation): 107 | extent = annotation.get_timeline().extent() 108 | if duration is None: 109 | notebook.crop = extent 110 | else: 111 | notebook.crop = Segment(extent.end - duration, extent.end) 112 | plt.rcParams["figure.figsize"] = (8, 2) 113 | notebook.plot_annotation(annotation) 114 | plt.tight_layout() 115 | plt.show() 116 | 117 | return apply 118 | -------------------------------------------------------------------------------- /table1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/juanmc2005/diart/392d53a1b0cd67701ecc20b683bb10614df2f7fc/table1.png -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import pytest 4 | import torch 5 | 6 | from diart.models import SegmentationModel, EmbeddingModel 7 | 8 | 9 | class DummySegmentationModel: 10 | def to(self, device): 11 | pass 12 | 13 | def __call__(self, waveform: torch.Tensor) -> torch.Tensor: 14 | assert waveform.ndim == 3 15 | 16 | batch_size, num_channels, num_samples = waveform.shape 17 | num_frames = random.randint(250, 500) 18 | num_speakers = random.randint(3, 5) 19 | 20 | return torch.rand(batch_size, num_frames, num_speakers) 21 | 22 | 23 | class DummyEmbeddingModel: 24 | def to(self, device): 25 | pass 26 | 27 | def __call__(self, waveform: torch.Tensor, weights: torch.Tensor) -> torch.Tensor: 28 | assert waveform.ndim == 3 29 | assert weights.ndim == 2 30 | 31 | batch_size, num_channels, num_samples = waveform.shape 32 | batch_size_weights, num_frames = weights.shape 33 | 34 | assert batch_size == batch_size_weights 35 | 36 | embedding_dim = random.randint(128, 512) 37 | 38 | return torch.randn(batch_size, embedding_dim) 39 | 40 | 41 | @pytest.fixture(scope="session") 42 | def segmentation_model() -> SegmentationModel: 43 | return SegmentationModel(DummySegmentationModel) 44 | 45 | 46 | @pytest.fixture(scope="session") 47 | def embedding_model() -> EmbeddingModel: 48 | return EmbeddingModel(DummyEmbeddingModel) 49 | -------------------------------------------------------------------------------- /tests/data/audio/sample.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/juanmc2005/diart/392d53a1b0cd67701ecc20b683bb10614df2f7fc/tests/data/audio/sample.wav -------------------------------------------------------------------------------- /tests/data/rttm/latency_0.5.rttm: -------------------------------------------------------------------------------- 1 | SPEAKER sample 1 6.675 0.533 speaker0 2 | SPEAKER sample 1 7.625 1.883 speaker0 3 | SPEAKER sample 1 9.508 1.000 speaker1 4 | SPEAKER sample 1 10.508 0.567 speaker0 5 | SPEAKER sample 1 10.625 4.133 speaker1 6 | SPEAKER sample 1 14.325 3.733 speaker0 7 | SPEAKER sample 1 18.058 3.450 speaker1 8 | SPEAKER sample 1 18.325 0.183 speaker0 9 | SPEAKER sample 1 21.508 0.017 speaker0 10 | SPEAKER sample 1 21.775 0.233 speaker1 11 | SPEAKER sample 1 22.008 6.633 speaker0 12 | SPEAKER sample 1 28.508 1.500 speaker1 13 | SPEAKER sample 1 29.958 0.050 speaker0 14 | -------------------------------------------------------------------------------- /tests/data/rttm/latency_1.rttm: -------------------------------------------------------------------------------- 1 | SPEAKER sample 1 6.708 0.450 speaker0 2 | SPEAKER sample 1 7.625 1.383 speaker0 3 | SPEAKER sample 1 9.008 1.500 speaker1 4 | SPEAKER sample 1 10.008 1.067 speaker0 5 | SPEAKER sample 1 10.592 4.200 speaker1 6 | SPEAKER sample 1 14.308 3.700 speaker0 7 | SPEAKER sample 1 18.042 3.250 speaker1 8 | SPEAKER sample 1 18.508 0.033 speaker0 9 | SPEAKER sample 1 21.108 0.383 speaker0 10 | SPEAKER sample 1 21.508 0.033 speaker1 11 | SPEAKER sample 1 21.775 6.817 speaker0 12 | SPEAKER sample 1 28.008 2.000 speaker1 13 | SPEAKER sample 1 29.975 0.033 speaker0 14 | -------------------------------------------------------------------------------- 
/tests/data/rttm/latency_2.rttm: -------------------------------------------------------------------------------- 1 | SPEAKER sample 1 6.725 0.433 speaker0 2 | SPEAKER sample 1 7.592 0.817 speaker0 3 | SPEAKER sample 1 8.475 1.617 speaker1 4 | SPEAKER sample 1 9.892 1.150 speaker0 5 | SPEAKER sample 1 10.625 4.133 speaker1 6 | SPEAKER sample 1 14.292 3.667 speaker0 7 | SPEAKER sample 1 18.008 3.533 speaker1 8 | SPEAKER sample 1 18.225 0.283 speaker0 9 | SPEAKER sample 1 21.758 6.867 speaker0 10 | SPEAKER sample 1 27.875 2.133 speaker1 11 | -------------------------------------------------------------------------------- /tests/data/rttm/latency_3.rttm: -------------------------------------------------------------------------------- 1 | SPEAKER sample 1 6.725 0.433 speaker0 2 | SPEAKER sample 1 7.625 0.467 speaker0 3 | SPEAKER sample 1 8.008 2.050 speaker1 4 | SPEAKER sample 1 9.875 1.167 speaker0 5 | SPEAKER sample 1 10.592 4.167 speaker1 6 | SPEAKER sample 1 14.292 3.667 speaker0 7 | SPEAKER sample 1 17.992 3.550 speaker1 8 | SPEAKER sample 1 18.192 0.367 speaker0 9 | SPEAKER sample 1 21.758 6.833 speaker0 10 | SPEAKER sample 1 27.825 2.183 speaker1 11 | -------------------------------------------------------------------------------- /tests/data/rttm/latency_4.rttm: -------------------------------------------------------------------------------- 1 | SPEAKER sample 1 6.742 0.400 speaker0 2 | SPEAKER sample 1 7.625 0.650 speaker0 3 | SPEAKER sample 1 8.092 1.950 speaker1 4 | SPEAKER sample 1 9.875 1.167 speaker0 5 | SPEAKER sample 1 10.575 4.183 speaker1 6 | SPEAKER sample 1 14.308 3.667 speaker0 7 | SPEAKER sample 1 17.992 3.550 speaker1 8 | SPEAKER sample 1 18.208 0.333 speaker0 9 | SPEAKER sample 1 21.758 6.817 speaker0 10 | SPEAKER sample 1 27.808 2.200 speaker1 11 | -------------------------------------------------------------------------------- /tests/data/rttm/latency_5.rttm: -------------------------------------------------------------------------------- 1 | SPEAKER sample 1 6.742 0.383 speaker0 2 | SPEAKER sample 1 7.625 0.667 speaker0 3 | SPEAKER sample 1 8.092 1.967 speaker1 4 | SPEAKER sample 1 9.875 1.167 speaker0 5 | SPEAKER sample 1 10.558 4.200 speaker1 6 | SPEAKER sample 1 14.308 3.667 speaker0 7 | SPEAKER sample 1 17.992 3.550 speaker1 8 | SPEAKER sample 1 18.208 0.317 speaker0 9 | SPEAKER sample 1 21.758 6.817 speaker0 10 | SPEAKER sample 1 27.808 2.200 speaker1 11 | -------------------------------------------------------------------------------- /tests/test_aggregation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from pyannote.core import SlidingWindow, SlidingWindowFeature 4 | 5 | from diart.blocks.aggregation import ( 6 | AggregationStrategy, 7 | HammingWeightedAverageStrategy, 8 | FirstOnlyStrategy, 9 | AverageStrategy, 10 | DelayedAggregation, 11 | ) 12 | 13 | 14 | def test_strategy_build(): 15 | strategy = AggregationStrategy.build("mean") 16 | assert isinstance(strategy, AverageStrategy) 17 | 18 | strategy = AggregationStrategy.build("hamming") 19 | assert isinstance(strategy, HammingWeightedAverageStrategy) 20 | 21 | strategy = AggregationStrategy.build("first") 22 | assert isinstance(strategy, FirstOnlyStrategy) 23 | 24 | with pytest.raises(Exception): 25 | AggregationStrategy.build("invalid") 26 | 27 | 28 | def test_aggregation(): 29 | duration = 5 30 | frames = 500 31 | step = 0.5 32 | speakers = 2 33 | start_time = 10 34 | resolution = duration / frames 35 | 36 | dagg1 = 
DelayedAggregation(step=step, latency=2, strategy="mean") 37 | dagg2 = DelayedAggregation(step=step, latency=2, strategy="hamming") 38 | dagg3 = DelayedAggregation(step=step, latency=2, strategy="first") 39 | 40 | for dagg in [dagg1, dagg2, dagg3]: 41 | assert dagg.num_overlapping_windows == 4 42 | 43 | buffers = [ 44 | SlidingWindowFeature( 45 | np.random.rand(frames, speakers), 46 | SlidingWindow( 47 | start=(i + start_time) * step, duration=resolution, step=resolution 48 | ), 49 | ) 50 | for i in range(dagg1.num_overlapping_windows) 51 | ] 52 | 53 | for dagg in [dagg1, dagg2, dagg3]: 54 | assert dagg(buffers).data.shape == (51, 2) 55 | -------------------------------------------------------------------------------- /tests/test_diarization.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import random 4 | 5 | import pytest 6 | 7 | from diart import SpeakerDiarizationConfig, SpeakerDiarization 8 | from utils import build_waveform_swf 9 | 10 | 11 | @pytest.fixture 12 | def random_diarization_config( 13 | segmentation_model, embedding_model 14 | ) -> SpeakerDiarizationConfig: 15 | duration = round(random.uniform(1, 10), 1) 16 | step = round(random.uniform(0.1, duration), 1) 17 | latency = round(random.uniform(step, duration), 1) 18 | return SpeakerDiarizationConfig( 19 | segmentation=segmentation_model, 20 | embedding=embedding_model, 21 | duration=duration, 22 | step=step, 23 | latency=latency, 24 | ) 25 | 26 | 27 | @pytest.fixture(scope="session") 28 | def min_latency_config(segmentation_model, embedding_model) -> SpeakerDiarizationConfig: 29 | return SpeakerDiarizationConfig( 30 | segmentation=segmentation_model, 31 | embedding=embedding_model, 32 | duration=5, 33 | step=0.5, 34 | latency="min", 35 | ) 36 | 37 | 38 | @pytest.fixture(scope="session") 39 | def max_latency_config(segmentation_model, embedding_model) -> SpeakerDiarizationConfig: 40 | return SpeakerDiarizationConfig( 41 | segmentation=segmentation_model, 42 | embedding=embedding_model, 43 | duration=5, 44 | step=0.5, 45 | latency="max", 46 | ) 47 | 48 | 49 | def test_config( 50 | segmentation_model, embedding_model, min_latency_config, max_latency_config 51 | ): 52 | duration = round(random.uniform(1, 10), 1) 53 | step = round(random.uniform(0.1, duration), 1) 54 | latency = round(random.uniform(step, duration), 1) 55 | config = SpeakerDiarizationConfig( 56 | segmentation=segmentation_model, 57 | embedding=embedding_model, 58 | duration=duration, 59 | step=step, 60 | latency=latency, 61 | ) 62 | 63 | assert config.duration == duration 64 | assert config.step == step 65 | assert config.latency == latency 66 | assert min_latency_config.latency == min_latency_config.step 67 | assert max_latency_config.latency == max_latency_config.duration 68 | 69 | 70 | def test_bad_latency(segmentation_model, embedding_model): 71 | duration = round(random.uniform(1, 10), 1) 72 | step = round(random.uniform(0.5, duration - 0.2), 1) 73 | latency_too_low = round(random.uniform(0, step - 0.1), 1) 74 | latency_too_high = round(random.uniform(duration + 0.1, 100), 1) 75 | 76 | config1 = SpeakerDiarizationConfig( 77 | segmentation=segmentation_model, 78 | embedding=embedding_model, 79 | duration=duration, 80 | step=step, 81 | latency=latency_too_low, 82 | ) 83 | config2 = SpeakerDiarizationConfig( 84 | segmentation=segmentation_model, 85 | embedding=embedding_model, 86 | duration=duration, 87 | step=step, 88 | latency=latency_too_high, 89 | ) 90 | 91 | with 
pytest.raises(AssertionError): 92 | SpeakerDiarization(config1) 93 | 94 | with pytest.raises(AssertionError): 95 | SpeakerDiarization(config2) 96 | 97 | 98 | def test_pipeline_build(random_diarization_config): 99 | pipeline = SpeakerDiarization(random_diarization_config) 100 | 101 | assert pipeline.get_config_class() == SpeakerDiarizationConfig 102 | 103 | hparams = pipeline.hyper_parameters() 104 | hp_names = [hp.name for hp in hparams] 105 | assert len(set(hp_names)) == 3 106 | 107 | for hparam in hparams: 108 | assert hparam.low == 0 109 | if hparam.name in ["tau_active", "rho_update"]: 110 | assert hparam.high == 1 111 | elif hparam.name == "delta_new": 112 | assert hparam.high == 2 113 | else: 114 | assert False 115 | 116 | assert pipeline.config == random_diarization_config 117 | 118 | 119 | def test_timestamp_shift(random_diarization_config): 120 | pipeline = SpeakerDiarization(random_diarization_config) 121 | 122 | assert pipeline.timestamp_shift == 0 123 | 124 | new_shift = round(random.uniform(-10, 10), 1) 125 | pipeline.set_timestamp_shift(new_shift) 126 | assert pipeline.timestamp_shift == new_shift 127 | 128 | waveform = build_waveform_swf( 129 | random_diarization_config.duration, 130 | random_diarization_config.sample_rate, 131 | ) 132 | prediction, _ = pipeline([waveform])[0] 133 | 134 | for segment, _, label in prediction.itertracks(yield_label=True): 135 | assert segment.start >= new_shift 136 | assert segment.end >= new_shift 137 | 138 | pipeline.reset() 139 | assert pipeline.timestamp_shift == 0 140 | 141 | 142 | def test_call_min_latency(min_latency_config): 143 | pipeline = SpeakerDiarization(min_latency_config) 144 | waveform1 = build_waveform_swf( 145 | min_latency_config.duration, 146 | min_latency_config.sample_rate, 147 | start_time=0, 148 | ) 149 | waveform2 = build_waveform_swf( 150 | min_latency_config.duration, 151 | min_latency_config.sample_rate, 152 | min_latency_config.step, 153 | ) 154 | 155 | batch = [waveform1, waveform2] 156 | output = pipeline(batch) 157 | 158 | pred1, wave1 = output[0] 159 | pred2, wave2 = output[1] 160 | 161 | assert waveform1.data.shape[0] == wave1.data.shape[0] 162 | assert wave1.data.shape[0] > wave2.data.shape[0] 163 | 164 | pred1_timeline = pred1.get_timeline() 165 | pred2_timeline = pred2.get_timeline() 166 | pred1_duration = round(pred1_timeline[-1].end - pred1_timeline[0].start, 3) 167 | pred2_duration = round(pred2_timeline[-1].end - pred2_timeline[0].start, 3) 168 | 169 | expected_duration = round(min_latency_config.duration, 3) 170 | expected_step = round(min_latency_config.step, 3) 171 | assert not pred1_timeline or pred1_duration <= expected_duration 172 | assert not pred2_timeline or pred2_duration <= expected_step 173 | 174 | 175 | def test_call_max_latency(max_latency_config): 176 | pipeline = SpeakerDiarization(max_latency_config) 177 | waveform1 = build_waveform_swf( 178 | max_latency_config.duration, 179 | max_latency_config.sample_rate, 180 | start_time=0, 181 | ) 182 | waveform2 = build_waveform_swf( 183 | max_latency_config.duration, 184 | max_latency_config.sample_rate, 185 | max_latency_config.step, 186 | ) 187 | 188 | batch = [waveform1, waveform2] 189 | output = pipeline(batch) 190 | 191 | pred1, wave1 = output[0] 192 | pred2, wave2 = output[1] 193 | 194 | assert waveform1.data.shape[0] > wave1.data.shape[0] 195 | assert wave1.data.shape[0] == wave2.data.shape[0] 196 | 197 | pred1_timeline = pred1.get_timeline() 198 | pred2_timeline = pred2.get_timeline() 199 | pred1_duration = pred1_timeline[-1].end - 
pred1_timeline[0].start 200 | pred2_duration = pred2_timeline[-1].end - pred2_timeline[0].start 201 | 202 | expected_step = round(max_latency_config.step, 3) 203 | assert not pred1_timeline or round(pred1_duration, 3) <= expected_step 204 | assert not pred2_timeline or round(pred2_duration, 3) <= expected_step 205 | -------------------------------------------------------------------------------- /tests/test_end_to_end.py: -------------------------------------------------------------------------------- 1 | import math 2 | from pathlib import Path 3 | 4 | import pytest 5 | from pyannote.database.util import load_rttm 6 | 7 | from diart import SpeakerDiarization, SpeakerDiarizationConfig 8 | from diart.inference import StreamingInference 9 | from diart.models import SegmentationModel, EmbeddingModel 10 | from diart.sources import FileAudioSource 11 | 12 | MODEL_DIR = Path(__file__).parent.parent / "assets" / "models" 13 | DATA_DIR = Path(__file__).parent / "data" 14 | 15 | 16 | @pytest.fixture(scope="session") 17 | def segmentation(): 18 | model_path = MODEL_DIR / "segmentation_uint8.onnx" 19 | return SegmentationModel.from_pretrained(model_path) 20 | 21 | 22 | @pytest.fixture(scope="session") 23 | def embedding(): 24 | model_path = MODEL_DIR / "embedding_uint8.onnx" 25 | return EmbeddingModel.from_pretrained(model_path) 26 | 27 | 28 | @pytest.fixture(scope="session") 29 | def make_config(segmentation, embedding): 30 | def _config(latency): 31 | return SpeakerDiarizationConfig( 32 | segmentation=segmentation, 33 | embedding=embedding, 34 | step=0.5, 35 | latency=latency, 36 | tau_active=0.507, 37 | rho_update=0.006, 38 | delta_new=1.057 39 | ) 40 | return _config 41 | 42 | 43 | @pytest.mark.parametrize("source_file", [DATA_DIR / "audio" / "sample.wav"]) 44 | @pytest.mark.parametrize("latency", [0.5, 1, 2, 3, 4, 5]) 45 | def test_benchmark(make_config, source_file, latency): 46 | config = make_config(latency) 47 | pipeline = SpeakerDiarization(config) 48 | 49 | padding = pipeline.config.get_file_padding(source_file) 50 | source = FileAudioSource( 51 | source_file, 52 | pipeline.config.sample_rate, 53 | padding, 54 | pipeline.config.step, 55 | ) 56 | 57 | pipeline.set_timestamp_shift(-padding[0]) 58 | inference = StreamingInference( 59 | pipeline, 60 | source, 61 | do_profile=False, 62 | do_plot=False, 63 | show_progress=False 64 | ) 65 | 66 | pred = inference() 67 | 68 | expected_file = (DATA_DIR / "rttm" / f"latency_{latency}.rttm") 69 | expected = load_rttm(expected_file).popitem()[1] 70 | 71 | assert len(pred) == len(expected) 72 | for track1, track2 in zip(pred.itertracks(yield_label=True), expected.itertracks(yield_label=True)): 73 | pred_segment, _, pred_spk = track1 74 | expected_segment, _, expected_spk = track2 75 | # We can tolerate a difference of up to 50ms 76 | assert math.isclose(pred_segment.start, expected_segment.start, abs_tol=0.05) 77 | assert math.isclose(pred_segment.end, expected_segment.end, abs_tol=0.05) 78 | assert pred_spk == expected_spk 79 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import random 3 | import numpy as np 4 | from pyannote.core import SlidingWindowFeature, SlidingWindow 5 | 6 | 7 | def build_waveform_swf( 8 | duration: float, sample_rate: int, start_time: float | None = None 9 | ) -> SlidingWindowFeature: 10 | start_time = round(random.uniform(0, 600), 1) if start_time is 
None else start_time 11 | chunk_size = int(duration * sample_rate) 12 | resolution = duration / chunk_size 13 | samples = np.random.randn(chunk_size, 1) 14 | sliding_window = SlidingWindow( 15 | start=start_time, step=resolution, duration=resolution 16 | ) 17 | return SlidingWindowFeature(samples, sliding_window) 18 | --------------------------------------------------------------------------------
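For reference, a minimal sketch of how `build_waveform_swf` is consumed by the diarization tests above (the 5-second duration, 16 kHz sample rate and zero start time are illustrative values, not requirements of the helper):

    from utils import build_waveform_swf  # same import style as test_diarization.py

    # 5 seconds at 16 kHz -> 80000 mono samples, one column per channel
    swf = build_waveform_swf(duration=5.0, sample_rate=16000, start_time=0.0)
    assert swf.data.shape == (80000, 1)
    # The sliding window starts exactly where requested
    assert swf.sliding_window.start == 0.0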