├── .github
│   ├── FUNDING.yml
│   └── workflows
│       ├── lint.yml
│       ├── pytest.yml
│       ├── python-publish.yml
│       └── quick-runs.yml
├── .gitignore
├── .readthedocs.yaml
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── assets
│   └── models
│       ├── embedding_uint8.onnx
│       └── segmentation_uint8.onnx
├── demo.gif
├── docs
│   ├── Makefile
│   ├── _static
│   │   └── logo.png
│   ├── conf.py
│   ├── index.rst
│   ├── make.bat
│   └── requirements.txt
├── environment.yml
├── expected_outputs
│   ├── offline
│   │   ├── vbx+overlap_aware_segmentation
│   │   │   ├── AMI.rttm
│   │   │   ├── DIHARD2.rttm
│   │   │   ├── DIHARD3.rttm
│   │   │   └── VoxConverse.rttm
│   │   └── vbx
│   │       ├── AMI.rttm
│   │       ├── DIHARD2.rttm
│   │       ├── DIHARD3.rttm
│   │       └── VoxConverse.rttm
│   └── online
│       ├── 0.5s
│       │   ├── AMI.rttm
│       │   ├── DIHARD3.rttm
│       │   └── VoxConverse.rttm
│       ├── 1.0s
│       │   ├── AMI.rttm
│       │   ├── DIHARD2_reoptimized.rttm
│       │   ├── DIHARD3.rttm
│       │   └── VoxConverse.rttm
│       ├── 2.0s
│       │   ├── AMI.rttm
│       │   ├── DIHARD3.rttm
│       │   └── VoxConverse.rttm
│       ├── 3.0s
│       │   ├── AMI.rttm
│       │   ├── DIHARD3.rttm
│       │   └── VoxConverse.rttm
│       ├── 4.0s
│       │   ├── AMI.rttm
│       │   ├── DIHARD3.rttm
│       │   └── VoxConverse.rttm
│       ├── 5.0s+oracle_segmentation
│       │   ├── AMI.rttm
│       │   ├── DIHARD2.rttm
│       │   ├── DIHARD3.rttm
│       │   └── VoxConverse.rttm
│       ├── 5.0s-overlap_aware_embeddings
│       │   ├── AMI.rttm
│       │   ├── DIHARD2.rttm
│       │   ├── DIHARD3.rttm
│       │   └── VoxConverse.rttm
│       └── 5.0s
│           ├── AMI.rttm
│           ├── DIHARD2.rttm
│           ├── DIHARD3.rttm
│           └── VoxConverse.rttm
├── figure1.png
├── figure5.png
├── logo.jpg
├── paper.pdf
├── pipeline.gif
├── pyproject.toml
├── requirements.txt
├── setup.cfg
├── setup.py
├── src
│   └── diart
│       ├── __init__.py
│       ├── argdoc.py
│       ├── audio.py
│       ├── blocks
│       │   ├── __init__.py
│       │   ├── aggregation.py
│       │   ├── base.py
│       │   ├── clustering.py
│       │   ├── diarization.py
│       │   ├── embedding.py
│       │   ├── segmentation.py
│       │   ├── utils.py
│       │   └── vad.py
│       ├── console
│       │   ├── __init__.py
│       │   ├── benchmark.py
│       │   ├── client.py
│       │   ├── serve.py
│       │   ├── stream.py
│       │   └── tune.py
│       ├── features.py
│       ├── functional.py
│       ├── inference.py
│       ├── mapping.py
│       ├── models.py
│       ├── operators.py
│       ├── optim.py
│       ├── progress.py
│       ├── sinks.py
│       ├── sources.py
│       └── utils.py
├── table1.png
└── tests
    ├── conftest.py
    ├── data
    │   ├── audio
    │   │   └── sample.wav
    │   └── rttm
    │       ├── latency_0.5.rttm
    │       ├── latency_1.rttm
    │       ├── latency_2.rttm
    │       ├── latency_3.rttm
    │       ├── latency_4.rttm
    │       └── latency_5.rttm
    ├── test_aggregation.py
    ├── test_diarization.py
    ├── test_end_to_end.py
    └── utils.py
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | # These are supported funding model platforms
2 |
3 | github: [juanmc2005] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
4 | # patreon: # Replace with a single Patreon username
5 | # open_collective: # Replace with a single Open Collective username
6 | # ko_fi: # Replace with a single Ko-fi username
7 | # tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
8 | # community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
9 | # liberapay: # Replace with a single Liberapay username
10 | # issuehunt: # Replace with a single IssueHunt username
11 | # otechie: # Replace with a single Otechie username
12 | # lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
13 | # custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
14 |
--------------------------------------------------------------------------------
/.github/workflows/lint.yml:
--------------------------------------------------------------------------------
1 | name: Lint
2 |
3 | on:
4 | pull_request:
5 | branches:
6 | - main
7 | - develop
8 |
9 | jobs:
10 | lint:
11 | runs-on: ubuntu-latest
12 | steps:
13 | - uses: actions/checkout@v3
14 | - uses: psf/black@stable
15 | with:
16 | options: "--check --verbose"
17 | src: "./src/diart"
18 | version: "23.10.1"
19 |
--------------------------------------------------------------------------------
/.github/workflows/pytest.yml:
--------------------------------------------------------------------------------
1 | name: Pytest
2 |
3 | on:
4 | pull_request:
5 | branches:
6 | - main
7 | - develop
8 |
9 | jobs:
10 | test:
11 | runs-on: ubuntu-latest
12 |
13 | strategy:
14 | matrix:
15 | python-version: ["3.10", "3.11", "3.12"]
16 |
17 | steps:
18 | - name: Checkout code
19 | uses: actions/checkout@v3
20 |
21 | - name: Set up Python ${{ matrix.python-version }}
22 | uses: actions/setup-python@v3
23 | with:
24 | python-version: ${{ matrix.python-version }}
25 |
26 | - name: Install apt dependencies
27 | run: |
28 | sudo apt-get update
29 | sudo apt-get -y install ffmpeg libportaudio2
30 |
31 | - name: Install pip dependencies
32 | run: |
33 | python -m pip install --upgrade pip
34 | pip install .[tests]
35 |
36 | - name: Run tests
37 | run: |
38 | pytest
39 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | permissions:
16 | contents: read
17 |
18 | jobs:
19 | deploy:
20 |
21 | runs-on: ubuntu-latest
22 |
23 | steps:
24 | - uses: actions/checkout@v3
25 | - name: Set up Python
26 | uses: actions/setup-python@v3
27 | with:
28 | python-version: '3.x'
29 | - name: Install dependencies
30 | run: |
31 | python -m pip install --upgrade pip
32 | pip install build
33 | - name: Build package
34 | run: python -m build
35 | - name: Publish package
36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
37 | with:
38 | user: __token__
39 | password: ${{ secrets.PYPI_API_TOKEN }}
40 |
--------------------------------------------------------------------------------
/.github/workflows/quick-runs.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
3 |
4 | name: Quick Runs
5 |
6 | on:
7 | pull_request:
8 | branches: [ "main", "develop" ]
9 |
10 | jobs:
11 | build:
12 |
13 | runs-on: ubuntu-latest
14 | strategy:
15 | fail-fast: false
16 | matrix:
17 | python-version: ["3.10", "3.11", "3.12"]
18 |
19 | steps:
20 | - uses: actions/checkout@v3
21 | - name: Set up Python ${{ matrix.python-version }}
22 | uses: actions/setup-python@v3
23 | with:
24 | python-version: ${{ matrix.python-version }}
25 | - name: Download data
26 | run: |
27 | mkdir audio rttms trash
28 | wget --no-verbose --show-progress --continue -O audio/ES2002a_long.wav http://groups.inf.ed.ac.uk/ami/AMICorpusMirror/amicorpus/ES2002a/audio/ES2002a.Mix-Headset.wav
29 | wget --no-verbose --show-progress --continue -O audio/ES2002b_long.wav http://groups.inf.ed.ac.uk/ami/AMICorpusMirror/amicorpus/ES2002b/audio/ES2002b.Mix-Headset.wav
30 | wget --no-verbose --show-progress --continue -O rttms/ES2002a_long.rttm https://raw.githubusercontent.com/pyannote/AMI-diarization-setup/main/only_words/rttms/train/ES2002a.rttm
31 | wget --no-verbose --show-progress --continue -O rttms/ES2002b_long.rttm https://raw.githubusercontent.com/pyannote/AMI-diarization-setup/main/only_words/rttms/train/ES2002b.rttm
32 | - name: Install apt dependencies
33 | run: |
34 | sudo apt-get update
35 | sudo apt-get -y install ffmpeg libportaudio2 sox
36 | - name: Install pip dependencies
37 | run: |
38 | python -m pip install --upgrade pip
39 | pip install .
40 | pip install onnxruntime==1.18.0
41 | - name: Crop audio and rttm
42 | run: |
43 | sox audio/ES2002a_long.wav audio/ES2002a.wav trim 00:40 00:30
44 | sox audio/ES2002b_long.wav audio/ES2002b.wav trim 00:10 00:30
45 | head -n 4 rttms/ES2002a_long.rttm > rttms/ES2002a.rttm
46 | head -n 7 rttms/ES2002b_long.rttm > rttms/ES2002b.rttm
47 | rm audio/ES2002a_long.wav
48 | rm audio/ES2002b_long.wav
49 | rm rttms/ES2002a_long.rttm
50 | rm rttms/ES2002b_long.rttm
51 | - name: Run stream
52 | run: |
53 | diart.stream audio/ES2002a.wav --segmentation assets/models/segmentation_uint8.onnx --embedding assets/models/embedding_uint8.onnx --output trash --no-plot
54 | - name: Run benchmark
55 | run: |
56 | diart.benchmark audio --reference rttms --batch-size 4 --segmentation assets/models/segmentation_uint8.onnx --embedding assets/models/embedding_uint8.onnx
57 | - name: Run tuning
58 | run: |
59 | diart.tune audio --reference rttms --batch-size 4 --num-iter 2 --output trash --segmentation assets/models/segmentation_uint8.onnx --embedding assets/models/embedding_uint8.onnx
60 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | # Created by https://www.toptal.com/developers/gitignore/api/python,pycharm+all
3 | # Edit at https://www.toptal.com/developers/gitignore?templates=python,pycharm+all
4 |
5 | ### PyCharm+all ###
6 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
7 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
8 |
9 | # User-specific stuff
10 | .idea/**/workspace.xml
11 | .idea/**/tasks.xml
12 | .idea/**/usage.statistics.xml
13 | .idea/**/dictionaries
14 | .idea/**/shelf
15 |
16 | # AWS User-specific
17 | .idea/**/aws.xml
18 |
19 | # Generated files
20 | .idea/**/contentModel.xml
21 |
22 | # Sensitive or high-churn files
23 | .idea/**/dataSources/
24 | .idea/**/dataSources.ids
25 | .idea/**/dataSources.local.xml
26 | .idea/**/sqlDataSources.xml
27 | .idea/**/dynamic.xml
28 | .idea/**/uiDesigner.xml
29 | .idea/**/dbnavigator.xml
30 |
31 | # Gradle
32 | .idea/**/gradle.xml
33 | .idea/**/libraries
34 |
35 | # Gradle and Maven with auto-import
36 | # When using Gradle or Maven with auto-import, you should exclude module files,
37 | # since they will be recreated, and may cause churn. Uncomment if using
38 | # auto-import.
39 | # .idea/artifacts
40 | # .idea/compiler.xml
41 | # .idea/jarRepositories.xml
42 | # .idea/modules.xml
43 | # .idea/*.iml
44 | # .idea/modules
45 | # *.iml
46 | # *.ipr
47 |
48 | # CMake
49 | cmake-build-*/
50 |
51 | # Mongo Explorer plugin
52 | .idea/**/mongoSettings.xml
53 |
54 | # File-based project format
55 | *.iws
56 |
57 | # IntelliJ
58 | out/
59 |
60 | # mpeltonen/sbt-idea plugin
61 | .idea_modules/
62 |
63 | # JIRA plugin
64 | atlassian-ide-plugin.xml
65 |
66 | # Cursive Clojure plugin
67 | .idea/replstate.xml
68 |
69 | # Crashlytics plugin (for Android Studio and IntelliJ)
70 | com_crashlytics_export_strings.xml
71 | crashlytics.properties
72 | crashlytics-build.properties
73 | fabric.properties
74 |
75 | # Editor-based Rest Client
76 | .idea/httpRequests
77 |
78 | # Android studio 3.1+ serialized cache file
79 | .idea/caches/build_file_checksums.ser
80 |
81 | ### PyCharm+all Patch ###
82 | # Ignores the whole .idea folder and all .iml files
83 | # See https://github.com/joeblau/gitignore.io/issues/186 and https://github.com/joeblau/gitignore.io/issues/360
84 |
85 | .idea/
86 |
87 | # Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-249601023
88 |
89 | *.iml
90 | modules.xml
91 | .idea/misc.xml
92 | *.ipr
93 |
94 | # Sonarlint plugin
95 | .idea/sonarlint
96 |
97 | ### Python ###
98 | # Byte-compiled / optimized / DLL files
99 | __pycache__/
100 | *.py[cod]
101 | *$py.class
102 |
103 | # C extensions
104 | *.so
105 |
106 | # Distribution / packaging
107 | .Python
108 | build/
109 | develop-eggs/
110 | dist/
111 | downloads/
112 | eggs/
113 | .eggs/
114 | lib/
115 | lib64/
116 | parts/
117 | sdist/
118 | var/
119 | wheels/
120 | share/python-wheels/
121 | *.egg-info/
122 | .installed.cfg
123 | *.egg
124 | MANIFEST
125 |
126 | # PyInstaller
127 | # Usually these files are written by a python script from a template
128 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
129 | *.manifest
130 | *.spec
131 |
132 | # Installer logs
133 | pip-log.txt
134 | pip-delete-this-directory.txt
135 |
136 | # Unit test / coverage reports
137 | htmlcov/
138 | .tox/
139 | .nox/
140 | .coverage
141 | .coverage.*
142 | .cache
143 | nosetests.xml
144 | coverage.xml
145 | *.cover
146 | *.py,cover
147 | .hypothesis/
148 | .pytest_cache/
149 | cover/
150 |
151 | # Translations
152 | *.mo
153 | *.pot
154 |
155 | # Django stuff:
156 | *.log
157 | local_settings.py
158 | db.sqlite3
159 | db.sqlite3-journal
160 |
161 | # Flask stuff:
162 | instance/
163 | .webassets-cache
164 |
165 | # Scrapy stuff:
166 | .scrapy
167 |
168 | # Sphinx documentation
169 | docs/_build/
170 |
171 | # PyBuilder
172 | .pybuilder/
173 | target/
174 |
175 | # Jupyter Notebook
176 | .ipynb_checkpoints
177 |
178 | # IPython
179 | profile_default/
180 | ipython_config.py
181 |
182 | # pyenv
183 | # For a library or package, you might want to ignore these files since the code is
184 | # intended to run in multiple environments; otherwise, check them in:
185 | # .python-version
186 |
187 | # pipenv
188 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
189 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
190 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
191 | # install all needed dependencies.
192 | #Pipfile.lock
193 |
194 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
195 | __pypackages__/
196 |
197 | # Celery stuff
198 | celerybeat-schedule
199 | celerybeat.pid
200 |
201 | # SageMath parsed files
202 | *.sage.py
203 |
204 | # Environments
205 | .env
206 | .venv
207 | env/
208 | venv/
209 | ENV/
210 | env.bak/
211 | venv.bak/
212 |
213 | # Spyder project settings
214 | .spyderproject
215 | .spyproject
216 |
217 | # Rope project settings
218 | .ropeproject
219 |
220 | # mkdocs documentation
221 | /site
222 |
223 | # mypy
224 | .mypy_cache/
225 | .dmypy.json
226 | dmypy.json
227 |
228 | # Pyre type checker
229 | .pyre/
230 |
231 | # pytype static type analyzer
232 | .pytype/
233 |
234 | # Cython debug symbols
235 | cython_debug/
236 |
237 | # End of https://www.toptal.com/developers/gitignore/api/python,pycharm+all
238 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | version: 2
2 |
3 | build:
4 | os: "ubuntu-22.04"
5 | tools:
6 | python: "3.10"
7 |
8 | python:
9 | install:
10 | - requirements: docs/requirements.txt
11 | # Install diart before building the docs
12 | - method: pip
13 | path: .
14 |
15 | sphinx:
16 | configuration: docs/conf.py
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to diart
2 |
3 | Thank you for considering contributing to diart! We appreciate your time and effort to help make this project better.
4 |
5 | ## Before You Start
6 |
7 | 1. **Search for Existing Issues or Discussions:**
8 | - Before opening a new issue or discussion, please check if there's already an existing one related to your topic. This helps avoid duplicates and keeps discussions centralized.
9 |
10 | 2. **Discuss Your Contribution:**
11 | - If you plan to make a significant change, it's advisable to discuss it in an issue first. This ensures that your contribution aligns with the project's goals and avoids duplicated efforts.
12 |
13 | 3. **Questions about diart:**
14 | - For general questions about diart, use the discussion space on GitHub. This helps in fostering a collaborative environment and encourages knowledge-sharing.
15 |
16 | ## Opening Issues
17 |
18 | If you encounter a problem with diart or want to suggest an improvement, please follow these guidelines when opening an issue:
19 |
20 | - **Bug Reports:**
21 | - Clearly describe the error, including any relevant stack traces.
22 | - Provide a minimal, reproducible example that demonstrates the issue.
23 | - Mention the version of diart you are using (as well as any dependencies related to the bug).
24 |
25 | - **Feature Requests:**
26 | - Clearly outline the new feature you are proposing.
27 | - Explain how it would benefit the project.
28 |
29 | ## Opening Pull Requests
30 |
31 | We welcome and appreciate contributions! To ensure a smooth review process, please follow these guidelines when opening a pull request:
32 |
33 | - **Create a Branch:**
34 | - Work on your changes in a dedicated branch created from `develop`.
35 |
36 | - **Commit Messages:**
37 | - Write clear and concise commit messages, explaining the purpose of each change.
38 |
39 | - **Documentation:**
40 | - Update documentation when introducing new features or making changes that impact existing functionality.
41 |
42 | - **Tests:**
43 | - If applicable, add or update tests to cover your changes.
44 |
45 | - **Code Style:**
46 | - Follow the existing coding style of the project. We use `black` and `isort`.
47 |
48 | - **Discuss Before Major Changes:**
49 | - If your PR includes significant changes, discuss it in an issue first.
50 |
51 | - **Follow the existing workflow:**
52 | - Make sure to open your PR against `develop` (**not** `main`).
53 |
54 | ## Thank You
55 |
56 | Your contributions make diart better for everyone. Thank you for your time and dedication!
57 |
58 | Happy coding!
59 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Université Paris-Saclay
4 | Copyright (c) 2021 CNRS
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | 🌿 Build AI-powered real-time audio applications in a breeze 🌿
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 | ## ⚡ Quick introduction
57 |
58 | Diart is a python framework to build AI-powered real-time audio applications.
59 | Its key feature is the ability to recognize different speakers in real time with state-of-the-art performance,
60 | a task commonly known as "speaker diarization".
61 |
62 | The pipeline `diart.SpeakerDiarization` combines a speaker segmentation and a speaker embedding model
63 | to power an incremental clustering algorithm that gets more accurate as the conversation progresses:
64 |
65 |
66 |
67 |
68 |
69 | With diart you can also create your own custom AI pipeline, benchmark it,
70 | tune its hyper-parameters, and even serve it on the web using websockets.
71 |
72 | **We provide pre-trained pipelines for:**
73 |
74 | - Speaker Diarization
75 | - Voice Activity Detection
76 | - Transcription ([coming soon](https://github.com/juanmc2005/diart/pull/144))
77 | - [Speaker-Aware Transcription](https://betterprogramming.pub/color-your-captions-streamlining-live-transcriptions-with-diart-and-openais-whisper-6203350234ef) ([coming soon](https://github.com/juanmc2005/diart/pull/147))
78 |
79 | ## 💾 Installation
80 |
81 | **1) Make sure your system has the following dependencies:**
82 |
83 | ```
84 | ffmpeg < 4.4
85 | portaudio == 19.6.X
86 | libsndfile >= 1.2.2
87 | ```
88 |
89 | Alternatively, we provide an `environment.yml` file for a pre-configured conda environment:
90 |
91 | ```shell
92 | conda env create -f diart/environment.yml
93 | conda activate diart
94 | ```
95 |
96 | **2) Install the package:**
97 | ```shell
98 | pip install diart
99 | ```
100 |
101 | ### Get access to 🎹 pyannote models
102 |
103 | By default, diart is based on [pyannote.audio](https://github.com/pyannote/pyannote-audio) models from the [huggingface](https://huggingface.co/) hub.
104 | In order to use them, please follow these steps:
105 |
106 | 1) [Accept user conditions](https://huggingface.co/pyannote/segmentation) for the `pyannote/segmentation` model
107 | 2) [Accept user conditions](https://huggingface.co/pyannote/segmentation-3.0) for the newest `pyannote/segmentation-3.0` model
108 | 3) [Accept user conditions](https://huggingface.co/pyannote/embedding) for the `pyannote/embedding` model
109 | 4) Install [huggingface-cli](https://huggingface.co/docs/huggingface_hub/quick-start#install-the-hub-library) and [log in](https://huggingface.co/docs/huggingface_hub/quick-start#login) with your user access token (or provide it manually in diart CLI or API).
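
For reference, here is a minimal sketch of logging in programmatically with `huggingface_hub` (the token below is a placeholder; running `huggingface-cli login` once works just as well, and diart will use the cached token):

```python
from huggingface_hub import login

# Placeholder token: replace with your own user access token,
# or simply run `huggingface-cli login` once instead.
login(token="hf_xxxxxxxxxxxxxxxxxxxx")
```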
110 |
111 | ## 🎙️ Stream audio
112 |
113 | ### From the command line
114 |
115 | A recorded conversation:
116 |
117 | ```shell
118 | diart.stream /path/to/audio.wav
119 | ```
120 |
121 | A live conversation:
122 |
123 | ```shell
124 | # Use "microphone:ID" to select a non-default device
125 | # See `python -m sounddevice` for available devices
126 | diart.stream microphone
127 | ```
128 |
129 | By default, diart runs a speaker diarization pipeline, equivalent to setting `--pipeline SpeakerDiarization`,
130 | but you can also set it to `--pipeline VoiceActivityDetection`. See `diart.stream -h` for more options.
131 |
132 | ### From python
133 |
134 | Use `StreamingInference` to run a pipeline on an audio source and write the results to disk:
135 |
136 | ```python
137 | from diart import SpeakerDiarization
138 | from diart.sources import MicrophoneAudioSource
139 | from diart.inference import StreamingInference
140 | from diart.sinks import RTTMWriter
141 |
142 | pipeline = SpeakerDiarization()
143 | mic = MicrophoneAudioSource()
144 | inference = StreamingInference(pipeline, mic, do_plot=True)
145 | inference.attach_observers(RTTMWriter(mic.uri, "/output/file.rttm"))
146 | prediction = inference()
147 | ```
148 |
149 | For inference and evaluation on a dataset, we recommend using `Benchmark` (see notes on [reproducibility](#reproducibility)).
150 |
151 | ## 🧠 Models
152 |
153 | You can use other models with the `--segmentation` and `--embedding` arguments.
154 | Or in python:
155 |
156 | ```python
157 | import diart.models as m
158 |
159 | segmentation = m.SegmentationModel.from_pretrained("model_name")
160 | embedding = m.EmbeddingModel.from_pretrained("model_name")
161 | ```
162 |
163 | ### Pre-trained models
164 |
165 | Below is a list of all the models currently supported by diart:
166 |
167 | | Model Name | Model Type | CPU Time* | GPU Time* |
168 | |---------------------------------------------------------------------------------------------------------------------------|--------------|-----------|-----------|
169 | | [🤗](https://huggingface.co/pyannote/segmentation) `pyannote/segmentation` (default) | segmentation | 12ms | 8ms |
170 | | [🤗](https://huggingface.co/pyannote/segmentation-3.0) `pyannote/segmentation-3.0` | segmentation | 11ms | 8ms |
171 | | [🤗](https://huggingface.co/pyannote/embedding) `pyannote/embedding` (default) | embedding | 26ms | 12ms |
172 | | [🤗](https://huggingface.co/hbredin/wespeaker-voxceleb-resnet34-LM) `hbredin/wespeaker-voxceleb-resnet34-LM` (ONNX) | embedding | 48ms | 15ms |
173 | | [🤗](https://huggingface.co/pyannote/wespeaker-voxceleb-resnet34-LM) `pyannote/wespeaker-voxceleb-resnet34-LM` (PyTorch) | embedding | 150ms | 29ms |
174 | | [🤗](https://huggingface.co/speechbrain/spkrec-xvect-voxceleb) `speechbrain/spkrec-xvect-voxceleb` | embedding | 41ms | 15ms |
175 | | [🤗](https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb) `speechbrain/spkrec-ecapa-voxceleb` | embedding | 41ms | 14ms |
176 | | [🤗](https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb-mel-spec) `speechbrain/spkrec-ecapa-voxceleb-mel-spec` | embedding | 42ms | 14ms |
177 | | [🤗](https://huggingface.co/speechbrain/spkrec-resnet-voxceleb) `speechbrain/spkrec-resnet-voxceleb` | embedding | 41ms | 16ms |
178 | | [🤗](https://huggingface.co/nvidia/speakerverification_en_titanet_large) `nvidia/speakerverification_en_titanet_large` | embedding | 91ms | 16ms |
179 |
180 | The latency of segmentation models is measured in a VAD pipeline (5s chunks).
181 |
182 | The latency of embedding models is measured in a diarization pipeline using `pyannote/segmentation` (also 5s chunks).
183 |
184 | \* CPU: AMD Ryzen 9 - GPU: RTX 4060 Max-Q
185 |
186 | ### Custom models
187 |
188 | Third-party models can be integrated by providing a loader function:
189 |
190 | ```python
191 | from diart import SpeakerDiarization, SpeakerDiarizationConfig
192 | from diart.models import EmbeddingModel, SegmentationModel
193 |
194 | def segmentation_loader():
195 | # It should take a waveform and return a segmentation tensor
196 | return load_pretrained_model("my_model.ckpt")
197 |
198 | def embedding_loader():
199 | # It should take (waveform, weights) and return per-speaker embeddings
200 | return load_pretrained_model("my_other_model.ckpt")
201 |
202 | segmentation = SegmentationModel(segmentation_loader)
203 | embedding = EmbeddingModel(embedding_loader)
204 | config = SpeakerDiarizationConfig(
205 | segmentation=segmentation,
206 | embedding=embedding,
207 | )
208 | pipeline = SpeakerDiarization(config)
209 | ```
210 |
211 | If you have an ONNX model, you can use `from_onnx()`:
212 |
213 | ```python
214 | from diart.models import EmbeddingModel
215 |
216 | embedding = EmbeddingModel.from_onnx(
217 |     model_path="my_model.onnx",
218 | input_names=["x", "w"], # defaults to ["waveform", "weights"]
219 | output_name="output", # defaults to "embedding"
220 | )
221 | ```
222 |
223 | ## 📈 Tune hyper-parameters
224 |
225 | Diart implements an optimizer based on [optuna](https://optuna.readthedocs.io/en/stable/index.html) that allows you to tune pipeline hyper-parameters to your needs.
226 |
227 | ### From the command line
228 |
229 | ```shell
230 | diart.tune /wav/dir --reference /rttm/dir --output /output/dir
231 | ```
232 |
233 | See `diart.tune -h` for more options.
234 |
235 | ### From python
236 |
237 | ```python
238 | from diart.optim import Optimizer
239 |
240 | optimizer = Optimizer("/wav/dir", "/rttm/dir", "/output/dir")
241 | optimizer(num_iter=100)
242 | ```
243 |
244 | This will write results to an SQLite database in `/output/dir`.
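
The resulting study can then be inspected with optuna directly. This is only a sketch: the database filename and study name below are placeholders, so adapt them to what `diart.tune` actually writes to your output directory:

```python
import optuna

# Placeholder study name and database path: use the ones reported
# by diart.tune (or your own Optimizer) at the end of the run.
study = optuna.load_study(
    study_name="my_study",
    storage="sqlite:////output/dir/my_study.db",
)
print(study.best_params)  # best hyper-parameters found so far
print(study.best_value)   # corresponding best objective value
```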
245 |
246 | ### Distributed tuning
247 |
248 | For bigger datasets, it is sometimes more convenient to run multiple optimization processes in parallel.
249 | To do this, create a study on a [recommended DBMS](https://optuna.readthedocs.io/en/stable/tutorial/10_key_features/004_distributed.html#sphx-glr-tutorial-10-key-features-004-distributed-py) (e.g. MySQL or PostgreSQL) making sure that the study and database names match:
250 |
251 | ```shell
252 | mysql -u root -e "CREATE DATABASE IF NOT EXISTS example"
253 | optuna create-study --study-name "example" --storage "mysql://root@localhost/example"
254 | ```
255 |
256 | You can now run multiple identical optimizers pointing to this database:
257 |
258 | ```shell
259 | diart.tune /wav/dir --reference /rttm/dir --storage mysql://root@localhost/example
260 | ```
261 |
262 | or in python:
263 |
264 | ```python
265 | from diart.optim import Optimizer
266 | from optuna.samplers import TPESampler
267 | import optuna
268 |
269 | db = "mysql://root@localhost/example"
270 | study = optuna.load_study("example", db, TPESampler())
271 | optimizer = Optimizer("/wav/dir", "/rttm/dir", study)
272 | optimizer(num_iter=100)
273 | ```
274 |
275 | ## 🧠🔗 Build pipelines
276 |
277 | For more advanced use cases, diart also provides building blocks that can be combined to create your own pipeline.
278 | Streaming is powered by [RxPY](https://github.com/ReactiveX/RxPY), but the `blocks` module is completely independent and can be used separately.
279 |
280 | ### Example
281 |
282 | Obtain overlap-aware speaker embeddings from a microphone stream:
283 |
284 | ```python
285 | import rx.operators as ops
286 | import diart.operators as dops
287 | from diart.sources import MicrophoneAudioSource, FileAudioSource
288 | from diart.blocks import SpeakerSegmentation, OverlapAwareSpeakerEmbedding
289 |
290 | segmentation = SpeakerSegmentation.from_pretrained("pyannote/segmentation")
291 | embedding = OverlapAwareSpeakerEmbedding.from_pretrained("pyannote/embedding")
292 |
293 | source = MicrophoneAudioSource()
294 | # To take input from file:
295 | # source = FileAudioSource("/path/to/audio.wav", sample_rate=16000)
296 |
297 | # Make sure the models have been trained with this sample rate
298 | print(source.sample_rate)
299 |
300 | stream = source.stream.pipe(
301 | # Reformat stream to 5s duration and 500ms shift
302 | dops.rearrange_audio_stream(sample_rate=source.sample_rate),
303 | ops.map(lambda wav: (wav, segmentation(wav))),
304 | ops.starmap(embedding)
305 | ).subscribe(on_next=lambda emb: print(emb.shape))
306 |
307 | source.read()
308 | ```
309 |
310 | Output:
311 |
312 | ```
313 | # Shape is (batch_size, num_speakers, embedding_dim)
314 | torch.Size([1, 3, 512])
315 | torch.Size([1, 3, 512])
316 | torch.Size([1, 3, 512])
317 | ...
318 | ```
319 |
320 | ## 🌐 WebSockets
321 |
322 | Diart is also compatible with the WebSocket protocol to serve pipelines on the web.
323 |
324 | ### From the command line
325 |
326 | ```shell
327 | diart.serve --host 0.0.0.0 --port 7007
328 | diart.client microphone --host <server-address> --port 7007
329 | ```
330 |
331 | **Note:** make sure that the client uses the same `step` and `sample_rate` as the server with `--step` and `-sr`.
332 |
333 | See `-h` for more options.
334 |
335 | ### From python
336 |
337 | For customized solutions, a server can also be created in python using the `WebSocketAudioSource`:
338 |
339 | ```python
340 | from diart import SpeakerDiarization
341 | from diart.sources import WebSocketAudioSource
342 | from diart.inference import StreamingInference
343 |
344 | pipeline = SpeakerDiarization()
345 | source = WebSocketAudioSource(pipeline.config.sample_rate, "localhost", 7007)
346 | inference = StreamingInference(pipeline, source)
347 | inference.attach_hooks(lambda ann_wav: source.send(ann_wav[0].to_rttm()))
348 | prediction = inference()
349 | ```
350 |
351 | ## 🔬 Powered by research
352 |
353 | Diart is the official implementation of the paper
354 | [Overlap-aware low-latency online speaker diarization based on end-to-end local segmentation](https://github.com/juanmc2005/diart/blob/main/paper.pdf)
355 | by [Juan Manuel Coria](https://juanmc2005.github.io/),
356 | [Hervé Bredin](https://herve.niderb.fr),
357 | [Sahar Ghannay](https://saharghannay.github.io/)
358 | and [Sophie Rosset](https://perso.limsi.fr/rosset/).
359 |
360 |
361 | > We propose to address online speaker diarization as a combination of incremental clustering and local diarization applied to a rolling buffer updated every 500ms. Every single step of the proposed pipeline is designed to take full advantage of the strong ability of a recently proposed end-to-end overlap-aware segmentation to detect and separate overlapping speakers. In particular, we propose a modified version of the statistics pooling layer (initially introduced in the x-vector architecture) to give less weight to frames where the segmentation model predicts simultaneous speakers. Furthermore, we derive cannot-link constraints from the initial segmentation step to prevent two local speakers from being wrongfully merged during the incremental clustering step. Finally, we show how the latency of the proposed approach can be adjusted between 500ms and 5s to match the requirements of a particular use case, and we provide a systematic analysis of the influence of latency on the overall performance (on AMI, DIHARD and VoxConverse).
362 |
363 |
364 |
365 |
366 |
367 | ### Citation
368 |
369 | If you found diart useful, please make sure to cite our paper:
370 |
371 | ```bibtex
372 | @inproceedings{diart,
373 | author={Coria, Juan M. and Bredin, Hervé and Ghannay, Sahar and Rosset, Sophie},
374 | booktitle={2021 IEEE Automatic Speech Recognition and Understanding Workshop (ASRU)},
375 | title={Overlap-Aware Low-Latency Online Speaker Diarization Based on End-to-End Local Segmentation},
376 | year={2021},
377 | pages={1139-1146},
378 | doi={10.1109/ASRU51503.2021.9688044},
379 | }
380 | ```
381 |
382 | ### Reproducibility
383 |
384 | 
385 |
386 | **Important:** We highly recommend installing `pyannote.audio<3.1` to reproduce these results.
387 | For more information, see [this issue](https://github.com/juanmc2005/diart/issues/214).
388 |
389 | Diart aims to be lightweight and capable of real-time streaming in practical scenarios.
390 | Its performance is very close to what is reported in the paper (and sometimes even a bit better).
391 |
392 | To obtain the best results, make sure to use the following hyper-parameters:
393 |
394 | | Dataset | latency | tau | rho | delta |
395 | |-------------|---------|--------|--------|-------|
396 | | DIHARD III | any | 0.555 | 0.422 | 1.517 |
397 | | AMI | any | 0.507 | 0.006 | 1.057 |
398 | | VoxConverse | any | 0.576 | 0.915 | 0.648 |
399 | | DIHARD II | 1s | 0.619 | 0.326 | 0.997 |
400 | | DIHARD II | 5s | 0.555 | 0.422 | 1.517 |
401 |
402 | `diart.benchmark` and `diart.inference.Benchmark` can run, evaluate and measure the real-time latency of the pipeline. For instance, for a DIHARD III configuration:
403 |
404 | ```shell
405 | diart.benchmark /wav/dir --reference /rttm/dir --tau-active=0.555 --rho-update=0.422 --delta-new=1.517 --segmentation pyannote/segmentation@Interspeech2021
406 | ```
407 |
408 | or using the inference API:
409 |
410 | ```python
411 | from diart.inference import Benchmark, Parallelize
412 | from diart import SpeakerDiarization, SpeakerDiarizationConfig
413 | from diart.models import SegmentationModel
414 |
415 | benchmark = Benchmark("/wav/dir", "/rttm/dir")
416 |
417 | model_name = "pyannote/segmentation@Interspeech2021"
418 | model = SegmentationModel.from_pretrained(model_name)
419 | config = SpeakerDiarizationConfig(
420 | # Set the segmentation model used in the paper
421 | segmentation=model,
422 | step=0.5,
423 | latency=0.5,
424 | tau_active=0.555,
425 | rho_update=0.422,
426 | delta_new=1.517
427 | )
428 | benchmark(SpeakerDiarization, config)
429 |
430 | # Run the same benchmark in parallel
431 | p_benchmark = Parallelize(benchmark, num_workers=4)
432 | if __name__ == "__main__": # Needed for multiprocessing
433 | p_benchmark(SpeakerDiarization, config)
434 | ```
435 |
436 | This pre-calculates model outputs in batches, so it runs a lot faster.
437 | See `diart.benchmark -h` for more options.
438 |
439 | For convenience and to facilitate future comparisons, we also provide the
440 | expected outputs
441 | of the paper implementation in RTTM format for every entry of Table 1 and Figure 5.
442 | This includes the VBx offline topline as well as our proposed online approach with
443 | latencies 500ms, 1s, 2s, 3s, 4s, and 5s.
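
As a quick sketch of how these files can be compared against a reference, the DER can be computed with `pyannote.metrics` (already a diart dependency); the reference path below is a placeholder:

```python
from pyannote.database.util import load_rttm
from pyannote.metrics.diarization import DiarizationErrorRate

# One of the provided expected outputs and a (placeholder) reference RTTM.
# Both files must use the same URIs for the recordings they annotate.
hypothesis = load_rttm("expected_outputs/online/1.0s/AMI.rttm")
reference = load_rttm("/path/to/reference/AMI.rttm")

metric = DiarizationErrorRate()
for uri, ref in reference.items():
    metric(ref, hypothesis[uri])  # accumulate per-file components
print(f"DER = {abs(metric):.1%}")  # aggregated diarization error rate
```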
444 |
445 | 
446 |
447 | ## 📑 License
448 |
449 | ```
450 | MIT License
451 |
452 | Copyright (c) 2021 Université Paris-Saclay
453 | Copyright (c) 2021 CNRS
454 |
455 | Permission is hereby granted, free of charge, to any person obtaining a copy
456 | of this software and associated documentation files (the "Software"), to deal
457 | in the Software without restriction, including without limitation the rights
458 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
459 | copies of the Software, and to permit persons to whom the Software is
460 | furnished to do so, subject to the following conditions:
461 |
462 | The above copyright notice and this permission notice shall be included in all
463 | copies or substantial portions of the Software.
464 |
465 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
466 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
467 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
468 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
469 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
470 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
471 | SOFTWARE.
472 | ```
473 |
474 | Logo generated by DesignEvo free logo designer
475 |
--------------------------------------------------------------------------------
/assets/models/embedding_uint8.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juanmc2005/diart/392d53a1b0cd67701ecc20b683bb10614df2f7fc/assets/models/embedding_uint8.onnx
--------------------------------------------------------------------------------
/assets/models/segmentation_uint8.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juanmc2005/diart/392d53a1b0cd67701ecc20b683bb10614df2f7fc/assets/models/segmentation_uint8.onnx
--------------------------------------------------------------------------------
/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juanmc2005/diart/392d53a1b0cd67701ecc20b683bb10614df2f7fc/demo.gif
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/_static/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juanmc2005/diart/392d53a1b0cd67701ecc20b683bb10614df2f7fc/docs/_static/logo.png
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # For the full list of built-in configuration values, see the documentation:
4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
5 |
6 | # -- Project information -----------------------------------------------------
7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
8 |
9 | project = "diart"
10 | copyright = "2023, Juan Manuel Coria"
11 | author = "Juan Manuel Coria"
12 | release = "v0.9"
13 |
14 | # -- General configuration ---------------------------------------------------
15 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
16 |
17 | extensions = [
18 | "autoapi.extension",
19 | "sphinx.ext.coverage",
20 | "sphinx.ext.napoleon",
21 | "sphinx_mdinclude",
22 | ]
23 |
24 | autoapi_dirs = ["../src/diart"]
25 | autoapi_options = [
26 | "members",
27 | "undoc-members",
28 | "show-inheritance",
29 | "show-module-summary",
30 | "special-members",
31 | "imported-members",
32 | ]
33 |
34 | templates_path = ["_templates"]
35 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
36 |
37 | # -- Options for autodoc ----------------------------------------------------
38 | # https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#configuration
39 |
40 | # Automatically extract typehints when specified and place them in
41 | # descriptions of the relevant function/method.
42 | autodoc_typehints = "description"
43 |
44 | # Don't show class signature with the class' name.
45 | autodoc_class_signature = "separated"
46 |
47 | # -- Options for HTML output -------------------------------------------------
48 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
49 |
50 | html_theme = "furo"
51 | html_static_path = ["_static"]
52 | html_logo = "_static/logo.png"
53 | html_title = "diart documentation"
54 |
55 |
56 | def skip_submodules(app, what, name, obj, skip, options):
57 | return (
58 | name.endswith("__init__")
59 | or name.startswith("diart.console")
60 | or name.startswith("diart.argdoc")
61 | )
62 |
63 |
64 | def setup(sphinx):
65 | sphinx.connect("autoapi-skip-member", skip_submodules)
66 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | Get started with diart
2 | ======================
3 |
4 | .. mdinclude:: ../README.md
5 |
6 |
7 | Useful Links
8 | ============
9 |
10 | .. toctree::
11 | :maxdepth: 1
12 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 |
13 | %SPHINXBUILD% >NUL 2>NUL
14 | if errorlevel 9009 (
15 | echo.
16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
17 | echo.installed, then set the SPHINXBUILD environment variable to point
18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
19 | echo.may add the Sphinx directory to PATH.
20 | echo.
21 | echo.If you don't have Sphinx installed, grab it from
22 | echo.https://www.sphinx-doc.org/
23 | exit /b 1
24 | )
25 |
26 | if "%1" == "" goto help
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx==6.2.1
2 | sphinx-autoapi==3.0.0
3 | sphinx-mdinclude==0.5.3
4 | furo==2023.9.10
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: diart
2 | channels:
3 | - conda-forge
4 | - defaults
5 | dependencies:
6 | - python=3.10
7 | - portaudio=19.6.*
8 | - pysoundfile=0.12.*
9 | - ffmpeg[version='<4.4']
10 | - pip
11 | - pip:
12 | - .
--------------------------------------------------------------------------------
/figure1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juanmc2005/diart/392d53a1b0cd67701ecc20b683bb10614df2f7fc/figure1.png
--------------------------------------------------------------------------------
/figure5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juanmc2005/diart/392d53a1b0cd67701ecc20b683bb10614df2f7fc/figure5.png
--------------------------------------------------------------------------------
/logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juanmc2005/diart/392d53a1b0cd67701ecc20b683bb10614df2f7fc/logo.jpg
--------------------------------------------------------------------------------
/paper.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juanmc2005/diart/392d53a1b0cd67701ecc20b683bb10614df2f7fc/paper.pdf
--------------------------------------------------------------------------------
/pipeline.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juanmc2005/diart/392d53a1b0cd67701ecc20b683bb10614df2f7fc/pipeline.gif
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = [
3 | "setuptools >= 40.9.0",
4 | "wheel",
5 | ]
6 | build-backend = "setuptools.build_meta"
7 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.20.2,<2.0.0
2 | matplotlib>=3.3.3,<4.0.0
3 | rx>=3.2.0
4 | scipy>=1.6.0
5 | sounddevice>=0.4.2
6 | einops>=0.3.0
7 | tqdm>=4.64.0
8 | pandas>=1.4.2
9 | torch>=1.12.1
10 | torchvision>=0.14.0
11 | torchaudio>=2.0.2
12 | pyannote.audio>=2.1.1
13 | requests>=2.31.0
14 | pyannote.core>=4.5
15 | pyannote.database>=4.1.1
16 | pyannote.metrics>=3.2
17 | optuna>=2.10
18 | websocket-server>=0.6.4
19 | websocket-client>=0.58.0
20 | rich>=12.5.1
21 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | name=diart
3 | version=0.9.2
4 | author=Juan Manuel Coria
5 | description=A python framework to build AI for real-time speech
6 | long_description=file: README.md
7 | long_description_content_type=text/markdown
8 | keywords=speaker diarization, streaming, online, real time, rxpy
9 | url=https://github.com/juanmc2005/diart
10 | license=MIT
11 | classifiers=
12 | Development Status :: 4 - Beta
13 | License :: OSI Approved :: MIT License
14 | Topic :: Multimedia :: Sound/Audio :: Analysis
15 | Topic :: Multimedia :: Sound/Audio :: Speech
16 | Topic :: Scientific/Engineering :: Artificial Intelligence
17 |
18 | [options]
19 | package_dir=
20 | =src
21 | packages=find:
22 | install_requires=
23 | numpy>=1.20.2,<2.0.0
24 | matplotlib>=3.3.3,<4.0.0
25 | rx>=3.2.0
26 | scipy>=1.6.0
27 | sounddevice>=0.4.2
28 | einops>=0.3.0
29 | tqdm>=4.64.0
30 | pandas>=1.4.2
31 | torch>=1.12.1
32 | torchvision>=0.14.0
33 | torchaudio>=2.0.2
34 | pyannote.audio>=2.1.1
35 | requests>=2.31.0
36 | pyannote.core>=4.5
37 | pyannote.database>=4.1.1
38 | pyannote.metrics>=3.2
39 | optuna>=2.10
40 | websocket-server>=0.6.4
41 | websocket-client>=0.58.0
42 | rich>=12.5.1
43 |
44 | [options.extras_require]
45 | tests=
46 | pytest>=7.4.0,<8.0.0
47 | onnxruntime==1.18.0
48 |
49 | [options.packages.find]
50 | where=src
51 |
52 | [options.entry_points]
53 | console_scripts=
54 | diart.stream=diart.console.stream:run
55 | diart.benchmark=diart.console.benchmark:run
56 | diart.tune=diart.console.tune:run
57 | diart.serve=diart.console.serve:run
58 | diart.client=diart.console.client:run
59 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 | setuptools.setup()
3 |
--------------------------------------------------------------------------------
/src/diart/__init__.py:
--------------------------------------------------------------------------------
1 | from .blocks import (
2 | SpeakerDiarization,
3 | Pipeline,
4 | SpeakerDiarizationConfig,
5 | PipelineConfig,
6 | VoiceActivityDetection,
7 | VoiceActivityDetectionConfig,
8 | )
9 |
--------------------------------------------------------------------------------
/src/diart/argdoc.py:
--------------------------------------------------------------------------------
1 | SEGMENTATION = "Segmentation model name from pyannote"
2 | EMBEDDING = "Embedding model name from pyannote"
3 | DURATION = "Chunk duration (in seconds)"
4 | STEP = "Sliding window step (in seconds)"
5 | LATENCY = "System latency (in seconds). STEP <= LATENCY <= CHUNK_DURATION"
6 | TAU = "Probability threshold to consider a speaker as active. 0 <= TAU <= 1"
7 | RHO = "Speech ratio threshold to decide if centroids are updated with a given speaker. 0 <= RHO <= 1"
8 | DELTA = "Embedding-to-centroid distance threshold to flag a speaker as known or new. 0 <= DELTA <= 2"
9 | GAMMA = "Parameter gamma for overlapped speech penalty"
10 | BETA = "Parameter beta for overlapped speech penalty"
11 | MAX_SPEAKERS = "Maximum number of speakers"
12 | CPU = "Force models to run on CPU"
13 | BATCH_SIZE = "For segmentation and embedding pre-calculation. If BATCH_SIZE < 2, run fully online and estimate real-time latency"
14 | NUM_WORKERS = "Number of parallel workers"
15 | OUTPUT = "Directory to store the system's output in RTTM format"
16 | HF_TOKEN = "Huggingface authentication token for hosted models ('true' | 'false' | <token>). If 'true', it will use the token from huggingface-cli login"
17 | SAMPLE_RATE = "Sample rate of the audio stream"
18 | NORMALIZE_EMBEDDING_WEIGHTS = "Rescale embedding weights (min-max normalization) to be in the range [0, 1]. This is useful in some models without weighted statistics pooling that rely on masking, like Nvidia's NeMo or ECAPA-TDNN"
19 |
--------------------------------------------------------------------------------
/src/diart/audio.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Text, Union
3 |
4 | import torch
5 | import torchaudio
6 | from torchaudio.functional import resample
7 |
8 | torchaudio.set_audio_backend("soundfile")
9 |
10 |
11 | FilePath = Union[Text, Path]
12 |
13 |
14 | class AudioLoader:
15 | def __init__(self, sample_rate: int, mono: bool = True):
16 | self.sample_rate = sample_rate
17 | self.mono = mono
18 |
19 | def load(self, filepath: FilePath) -> torch.Tensor:
20 | """Load an audio file into a torch.Tensor.
21 |
22 | Parameters
23 | ----------
24 | filepath : FilePath
25 | Path to an audio file
26 |
27 | Returns
28 | -------
29 | waveform : torch.Tensor, shape (channels, samples)
30 | """
31 | waveform, sample_rate = torchaudio.load(filepath)
32 | # Get channel mean if mono
33 | if self.mono and waveform.shape[0] > 1:
34 | waveform = waveform.mean(dim=0, keepdim=True)
35 | # Resample if needed
36 | if self.sample_rate != sample_rate:
37 | waveform = resample(waveform, sample_rate, self.sample_rate)
38 | return waveform
39 |
40 | @staticmethod
41 | def get_duration(filepath: FilePath) -> float:
42 | """Get audio file duration in seconds.
43 |
44 | Parameters
45 | ----------
46 | filepath : FilePath
47 | Path to an audio file.
48 |
49 | Returns
50 | -------
51 | duration : float
52 | Duration in seconds.
53 | """
54 | info = torchaudio.info(filepath)
55 | return info.num_frames / info.sample_rate
56 |
--------------------------------------------------------------------------------
/src/diart/blocks/__init__.py:
--------------------------------------------------------------------------------
1 | from .aggregation import (
2 | AggregationStrategy,
3 | HammingWeightedAverageStrategy,
4 | AverageStrategy,
5 | FirstOnlyStrategy,
6 | DelayedAggregation,
7 | )
8 | from .clustering import OnlineSpeakerClustering
9 | from .embedding import (
10 | SpeakerEmbedding,
11 | OverlappedSpeechPenalty,
12 | EmbeddingNormalization,
13 | OverlapAwareSpeakerEmbedding,
14 | )
15 | from .segmentation import SpeakerSegmentation
16 | from .diarization import SpeakerDiarization, SpeakerDiarizationConfig
17 | from .base import PipelineConfig, Pipeline
18 | from .utils import Binarize, Resample, AdjustVolume
19 | from .vad import VoiceActivityDetection, VoiceActivityDetectionConfig
20 |
--------------------------------------------------------------------------------
/src/diart/blocks/aggregation.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Optional, List
3 |
4 | import numpy as np
5 | from pyannote.core import Segment, SlidingWindow, SlidingWindowFeature
6 | from typing_extensions import Literal
7 |
8 |
9 | class AggregationStrategy(ABC):
10 | """Abstract class representing a strategy to aggregate overlapping buffers
11 |
12 | Parameters
13 | ----------
14 | cropping_mode: ("strict", "loose", "center"), optional
15 | Defines the mode to crop buffer chunks as in pyannote.core.
16 | See https://pyannote.github.io/pyannote-core/reference.html#pyannote.core.SlidingWindowFeature.crop
17 | Defaults to "loose".
18 | """
19 |
20 | def __init__(self, cropping_mode: Literal["strict", "loose", "center"] = "loose"):
21 | assert cropping_mode in [
22 | "strict",
23 | "loose",
24 | "center",
25 | ], f"Invalid cropping mode `{cropping_mode}`"
26 | self.cropping_mode = cropping_mode
27 |
28 | @staticmethod
29 | def build(
30 | name: Literal["mean", "hamming", "first"],
31 | cropping_mode: Literal["strict", "loose", "center"] = "loose",
32 | ) -> "AggregationStrategy":
33 | """Build an AggregationStrategy instance based on its name"""
34 | assert name in ("mean", "hamming", "first")
35 | if name == "mean":
36 | return AverageStrategy(cropping_mode)
37 | elif name == "hamming":
38 | return HammingWeightedAverageStrategy(cropping_mode)
39 | else:
40 | return FirstOnlyStrategy(cropping_mode)
41 |
42 | def __call__(
43 | self, buffers: List[SlidingWindowFeature], focus: Segment
44 | ) -> SlidingWindowFeature:
45 | """Aggregate chunks over a specific region.
46 |
47 | Parameters
48 | ----------
49 | buffers: list of SlidingWindowFeature, shapes (frames, speakers)
50 | Buffers to aggregate
51 | focus: Segment
52 | Region to aggregate that is shared among the buffers
53 |
54 | Returns
55 | -------
56 | aggregation: SlidingWindowFeature, shape (cropped_frames, speakers)
57 | Aggregated values over the focus region
58 | """
59 | aggregation = self.aggregate(buffers, focus)
60 | resolution = focus.duration / aggregation.shape[0]
61 | resolution = SlidingWindow(
62 | start=focus.start, duration=resolution, step=resolution
63 | )
64 | return SlidingWindowFeature(aggregation, resolution)
65 |
66 | @abstractmethod
67 | def aggregate(
68 | self, buffers: List[SlidingWindowFeature], focus: Segment
69 | ) -> np.ndarray:
70 | pass
71 |
72 |
73 | class HammingWeightedAverageStrategy(AggregationStrategy):
74 | """Compute the average weighted by the corresponding Hamming-window aligned to each buffer"""
75 |
76 | def aggregate(
77 | self, buffers: List[SlidingWindowFeature], focus: Segment
78 | ) -> np.ndarray:
79 | num_frames, num_speakers = buffers[0].data.shape
80 | hamming, intersection = [], []
81 | for buffer in buffers:
82 | # Crop buffer to focus region
83 | b = buffer.crop(focus, mode=self.cropping_mode, fixed=focus.duration)
84 | # Crop Hamming window to focus region
85 | h = np.expand_dims(np.hamming(num_frames), axis=-1)
86 | h = SlidingWindowFeature(h, buffer.sliding_window)
87 | h = h.crop(focus, mode=self.cropping_mode, fixed=focus.duration)
88 | hamming.append(h.data)
89 | intersection.append(b.data)
90 | hamming, intersection = np.stack(hamming), np.stack(intersection)
91 | # Calculate weighted mean
92 | return np.sum(hamming * intersection, axis=0) / np.sum(hamming, axis=0)
93 |
94 |
95 | class AverageStrategy(AggregationStrategy):
96 | """Compute a simple average over the focus region"""
97 |
98 | def aggregate(
99 | self, buffers: List[SlidingWindowFeature], focus: Segment
100 | ) -> np.ndarray:
101 | # Stack all overlapping regions
102 | intersection = np.stack(
103 | [
104 | buffer.crop(focus, mode=self.cropping_mode, fixed=focus.duration)
105 | for buffer in buffers
106 | ]
107 | )
108 | return np.mean(intersection, axis=0)
109 |
110 |
111 | class FirstOnlyStrategy(AggregationStrategy):
112 | """Instead of aggregating, keep the first focus region in the buffer list"""
113 |
114 | def aggregate(
115 | self, buffers: List[SlidingWindowFeature], focus: Segment
116 | ) -> np.ndarray:
117 | return buffers[0].crop(focus, mode=self.cropping_mode, fixed=focus.duration)
118 |
119 |
120 | class DelayedAggregation:
121 | """Aggregate aligned overlapping windows of the same duration
122 | across sliding buffers with a specific step and latency.
123 |
124 | Parameters
125 | ----------
126 | step: float
127 | Shift between two consecutive buffers, in seconds.
128 | latency: float, optional
129 | Desired latency, in seconds. Defaults to step.
130 | The higher the latency, the more overlapping windows to aggregate.
131 | strategy: ("mean", "hamming", "first"), optional
132 | Specifies how to aggregate overlapping windows. Defaults to "hamming".
133 | "mean": simple average
134 | "hamming": average weighted by the Hamming window values (aligned to the buffer)
135 | "first": no aggregation, pick the first overlapping window
136 | cropping_mode: ("strict", "loose", "center"), optional
137 | Defines the mode to crop buffer chunks as in pyannote.core.
138 | See https://pyannote.github.io/pyannote-core/reference.html#pyannote.core.SlidingWindowFeature.crop
139 | Defaults to "loose".
140 |
141 | Example
142 | --------
143 | >>> duration = 5
144 | >>> frames = 500
145 | >>> step = 0.5
146 | >>> speakers = 2
147 | >>> start_time = 10
148 | >>> resolution = duration / frames
149 | >>> dagg = DelayedAggregation(step=step, latency=2, strategy="mean")
150 | >>> buffers = [
151 | >>> SlidingWindowFeature(
152 | >>> np.random.rand(frames, speakers),
153 | >>> SlidingWindow(start=(i + start_time) * step, duration=resolution, step=resolution)
154 | >>> )
155 | >>> for i in range(dagg.num_overlapping_windows)
156 | >>> ]
157 | >>> dagg.num_overlapping_windows
158 | ... 4
159 | >>> dagg(buffers).data.shape
160 | ... (51, 2) # Rounding errors are possible when cropping the buffers
161 | """
162 |
163 | def __init__(
164 | self,
165 | step: float,
166 | latency: Optional[float] = None,
167 | strategy: Literal["mean", "hamming", "first"] = "hamming",
168 | cropping_mode: Literal["strict", "loose", "center"] = "loose",
169 | ):
170 | self.step = step
171 | self.latency = latency
172 | self.strategy = strategy
173 | assert cropping_mode in [
174 | "strict",
175 | "loose",
176 | "center",
177 | ], f"Invalid cropping mode `{cropping_mode}`"
178 | self.cropping_mode = cropping_mode
179 |
180 | if self.latency is None:
181 | self.latency = self.step
182 |
183 | assert self.step <= self.latency, "Invalid latency requested"
184 |
185 | self.num_overlapping_windows = int(round(self.latency / self.step))
186 | self.aggregate = AggregationStrategy.build(self.strategy, self.cropping_mode)
187 |
188 | def _prepend(
189 | self,
190 | output_window: SlidingWindowFeature,
191 | output_region: Segment,
192 | buffers: List[SlidingWindowFeature],
193 | ):
194 | # FIXME instead of prepending the output of the first chunk,
195 | # add padding of `chunk_duration - latency` seconds at the
196 | # beginning of the stream so scores can be aggregated accordingly.
197 | # Remember to shift predictions by the padding.
198 | last_buffer = buffers[-1].extent
199 |         # Prepend predictions to match the latency when handling the first buffer
200 | if len(buffers) == 1 and last_buffer.start == 0:
201 | num_frames = output_window.data.shape[0]
202 | first_region = Segment(0, output_region.end)
203 | first_output = buffers[0].crop(
204 | first_region, mode=self.cropping_mode, fixed=first_region.duration
205 | )
206 | first_output[-num_frames:] = output_window.data
207 | resolution = output_region.end / first_output.shape[0]
208 | output_window = SlidingWindowFeature(
209 | first_output,
210 | SlidingWindow(start=0, duration=resolution, step=resolution),
211 | )
212 | return output_window
213 |
214 | def __call__(self, buffers: List[SlidingWindowFeature]) -> SlidingWindowFeature:
215 | # Determine overlapping region to aggregate
216 | start = buffers[-1].extent.end - self.latency
217 | region = Segment(start, start + self.step)
218 | return self._prepend(self.aggregate(buffers, region), region, buffers)
219 |
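
Usage sketch for the classes above (the values are illustrative assumptions, not defaults): build four overlapping prediction buffers shifted by `step` seconds and aggregate the region they share with the Hamming-weighted strategy.

import numpy as np
from pyannote.core import SlidingWindow, SlidingWindowFeature

duration, frames, step, latency, speakers = 5.0, 500, 0.5, 2.0, 2
resolution = duration / frames  # seconds per frame

# latency / step = 4 overlapping windows are needed before aggregating
dagg = DelayedAggregation(step=step, latency=latency, strategy="hamming")
buffers = [
    SlidingWindowFeature(
        np.random.rand(frames, speakers),
        SlidingWindow(start=i * step, duration=resolution, step=resolution),
    )
    for i in range(dagg.num_overlapping_windows)
]
output = dagg(buffers)
print(output.data.shape)  # roughly (step / resolution, speakers) = (50, 2), up to rounding
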
--------------------------------------------------------------------------------
/src/diart/blocks/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from dataclasses import dataclass
3 | from typing import Any, Tuple, Sequence, Text
4 |
5 | from pyannote.core import SlidingWindowFeature
6 | from pyannote.metrics.base import BaseMetric
7 |
8 | from .. import utils
9 | from ..audio import FilePath, AudioLoader
10 |
11 |
12 | @dataclass
13 | class HyperParameter:
14 | """Represents a pipeline hyper-parameter that can be tuned by diart"""
15 |
16 | name: Text
17 | """Name of the hyper-parameter (e.g. tau_active)"""
18 | low: float
19 | """Lowest value that this parameter can take"""
20 | high: float
21 | """Highest value that this parameter can take"""
22 |
23 | @staticmethod
24 | def from_name(name: Text) -> "HyperParameter":
25 | """Create a HyperParameter object given its name.
26 |
27 | Parameters
28 | ----------
29 | name: str
30 | Name of the hyper-parameter
31 |
32 | Returns
33 | -------
34 | HyperParameter
35 | """
36 | if name == "tau_active":
37 | return TauActive
38 | if name == "rho_update":
39 | return RhoUpdate
40 | if name == "delta_new":
41 | return DeltaNew
42 | raise ValueError(f"Hyper-parameter '{name}' not recognized")
43 |
44 |
45 | TauActive = HyperParameter("tau_active", low=0, high=1)
46 | RhoUpdate = HyperParameter("rho_update", low=0, high=1)
47 | DeltaNew = HyperParameter("delta_new", low=0, high=2)
48 |
49 |
50 | class PipelineConfig(ABC):
51 | """Configuration containing the required
52 | parameters to build and run a pipeline"""
53 |
54 | @property
55 | @abstractmethod
56 | def duration(self) -> float:
57 | """The duration of an input audio chunk (in seconds)"""
58 | pass
59 |
60 | @property
61 | @abstractmethod
62 | def step(self) -> float:
63 | """The step between two consecutive input audio chunks (in seconds)"""
64 | pass
65 |
66 | @property
67 | @abstractmethod
68 | def latency(self) -> float:
69 | """The algorithmic latency of the pipeline (in seconds).
70 | At time `t` of the audio stream, the pipeline will
71 | output predictions for time `t - latency`.
72 | """
73 | pass
74 |
75 | @property
76 | @abstractmethod
77 | def sample_rate(self) -> int:
78 | """The sample rate of the input audio stream"""
79 | pass
80 |
81 | def get_file_padding(self, filepath: FilePath) -> Tuple[float, float]:
82 | file_duration = AudioLoader(self.sample_rate, mono=True).get_duration(filepath)
83 | right = utils.get_padding_right(self.latency, self.step)
84 | left = utils.get_padding_left(file_duration + right, self.duration)
85 | return left, right
86 |
87 |
88 | class Pipeline(ABC):
89 | """Represents a streaming audio pipeline"""
90 |
91 | @staticmethod
92 | @abstractmethod
93 | def get_config_class() -> type:
94 | pass
95 |
96 | @staticmethod
97 | @abstractmethod
98 | def suggest_metric() -> BaseMetric:
99 | pass
100 |
101 | @staticmethod
102 | @abstractmethod
103 | def hyper_parameters() -> Sequence[HyperParameter]:
104 | pass
105 |
106 | @property
107 | @abstractmethod
108 | def config(self) -> PipelineConfig:
109 | pass
110 |
111 | @abstractmethod
112 | def reset(self):
113 | pass
114 |
115 | @abstractmethod
116 | def set_timestamp_shift(self, shift: float):
117 | pass
118 |
119 | @abstractmethod
120 | def __call__(
121 | self, waveforms: Sequence[SlidingWindowFeature]
122 | ) -> Sequence[Tuple[Any, SlidingWindowFeature]]:
123 | """Runs the next steps of the pipeline
124 | given a list of consecutive audio chunks.
125 |
126 | Parameters
127 | ----------
128 | waveforms: Sequence[SlidingWindowFeature]
129 | Consecutive chunk waveforms for the pipeline to ingest
130 |
131 | Returns
132 | -------
133 | Sequence[Tuple[Any, SlidingWindowFeature]]
134 | For each input waveform, a tuple containing
135 | the pipeline output and its respective audio
136 | """
137 | pass
138 |
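
A small sketch of how tunable hyper-parameters are resolved by name; the names below are exactly the ones handled by `from_name`.

param = HyperParameter.from_name("tau_active")
assert param is TauActive
print(param.name, param.low, param.high)  # tau_active 0 1

# Unrecognized names raise a ValueError
try:
    HyperParameter.from_name("not_a_parameter")
except ValueError as error:
    print(error)
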
--------------------------------------------------------------------------------
/src/diart/blocks/clustering.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Iterable, Tuple
2 |
3 | import numpy as np
4 | import torch
5 | from pyannote.core import SlidingWindowFeature
6 |
7 | from ..mapping import SpeakerMap, SpeakerMapBuilder
8 |
9 |
10 | class OnlineSpeakerClustering:
11 | """Implements constrained incremental online clustering of speakers and manages cluster centers.
12 |
13 | Parameters
14 | ----------
15 |     tau_active: float
16 | Threshold for detecting active speakers. This threshold is applied on the maximum value of per-speaker output
17 | activation of the local segmentation model.
18 | rho_update: float
19 | Threshold for considering the extracted embedding when updating the centroid of the local speaker.
20 | The centroid to which a local speaker is mapped is only updated if the ratio of speech/chunk duration
21 | of a given local speaker is greater than this threshold.
22 | delta_new: float
23 | Threshold on the distance between a speaker embedding and a centroid. If the distance between a local speaker and all
24 | centroids is larger than delta_new, then a new centroid is created for the current speaker.
25 |     metric: str, optional
26 |         The distance metric to use. Defaults to "cosine".
27 | max_speakers: int
28 | Maximum number of global speakers to track through a conversation. Defaults to 20.
29 | """
30 |
31 | def __init__(
32 | self,
33 | tau_active: float,
34 | rho_update: float,
35 | delta_new: float,
36 | metric: Optional[str] = "cosine",
37 | max_speakers: int = 20,
38 | ):
39 | self.tau_active = tau_active
40 | self.rho_update = rho_update
41 | self.delta_new = delta_new
42 | self.metric = metric
43 | self.max_speakers = max_speakers
44 | self.centers: Optional[np.ndarray] = None
45 | self.active_centers = set()
46 | self.blocked_centers = set()
47 |
48 | @property
49 | def num_free_centers(self) -> int:
50 | return self.max_speakers - self.num_known_speakers - self.num_blocked_speakers
51 |
52 | @property
53 | def num_known_speakers(self) -> int:
54 | return len(self.active_centers)
55 |
56 | @property
57 | def num_blocked_speakers(self) -> int:
58 | return len(self.blocked_centers)
59 |
60 | @property
61 | def inactive_centers(self) -> List[int]:
62 | return [
63 | c
64 | for c in range(self.max_speakers)
65 | if c not in self.active_centers or c in self.blocked_centers
66 | ]
67 |
68 | def get_next_center_position(self) -> Optional[int]:
69 | for center in range(self.max_speakers):
70 | if center not in self.active_centers and center not in self.blocked_centers:
71 | return center
72 |
73 | def init_centers(self, dimension: int):
74 | """Initializes the speaker centroid matrix
75 |
76 | Parameters
77 | ----------
78 | dimension: int
79 | Dimension of embeddings used for representing a speaker.
80 | """
81 | self.centers = np.zeros((self.max_speakers, dimension))
82 | self.active_centers = set()
83 | self.blocked_centers = set()
84 |
85 | def update(self, assignments: Iterable[Tuple[int, int]], embeddings: np.ndarray):
86 | """Updates the speaker centroids given a list of assignments and local speaker embeddings
87 |
88 | Parameters
89 | ----------
90 |         assignments: Iterable[Tuple[int, int]]
91 |             An iterable of (source, target) speaker index pairs, where the source is a local
92 |             speaker and the target is the global centroid it should update.
93 | embeddings: np.ndarray, shape (local_speakers, embedding_dim)
94 | Matrix containing embeddings for all local speakers.
95 | """
96 | if self.centers is not None:
97 | for l_spk, g_spk in assignments:
98 | assert g_spk in self.active_centers, "Cannot update unknown centers"
99 | self.centers[g_spk] += embeddings[l_spk]
100 |
101 | def add_center(self, embedding: np.ndarray) -> int:
102 | """Add a new speaker centroid initialized to a given embedding
103 |
104 | Parameters
105 | ----------
106 | embedding: np.ndarray
107 | Embedding vector of some local speaker
108 |
109 | Returns
110 | -------
111 | center_index: int
112 | Index of the created center
113 | """
114 | center = self.get_next_center_position()
115 | self.centers[center] = embedding
116 | self.active_centers.add(center)
117 | return center
118 |
119 | def identify(
120 | self, segmentation: SlidingWindowFeature, embeddings: torch.Tensor
121 | ) -> SpeakerMap:
122 | """Identify the centroids to which the input speaker embeddings belong.
123 |
124 | Parameters
125 | ----------
126 |         segmentation: SlidingWindowFeature, shape (frames, local_speakers)
127 |             Matrix of segmentation outputs
128 |         embeddings: torch.Tensor, shape (local_speakers, embedding_dim)
129 |             Matrix of embeddings
130 |
131 | Returns
132 | -------
133 | speaker_map: SpeakerMap
134 | A mapping from local speakers to global speakers.
135 | """
136 | embeddings = embeddings.detach().cpu().numpy()
137 | active_speakers = np.where(
138 | np.max(segmentation.data, axis=0) >= self.tau_active
139 | )[0]
140 | long_speakers = np.where(np.mean(segmentation.data, axis=0) >= self.rho_update)[
141 | 0
142 | ]
143 | # Remove speakers that have NaN embeddings
144 | no_nan_embeddings = np.where(~np.isnan(embeddings).any(axis=1))[0]
145 | active_speakers = np.intersect1d(active_speakers, no_nan_embeddings)
146 |
147 | num_local_speakers = segmentation.data.shape[1]
148 |
149 | if self.centers is None:
150 | self.init_centers(embeddings.shape[1])
151 | assignments = [
152 | (spk, self.add_center(embeddings[spk])) for spk in active_speakers
153 | ]
154 | return SpeakerMapBuilder.hard_map(
155 | shape=(num_local_speakers, self.max_speakers),
156 | assignments=assignments,
157 | maximize=False,
158 | )
159 |
160 | # Obtain a mapping based on distances between embeddings and centers
161 | dist_map = SpeakerMapBuilder.dist(embeddings, self.centers, self.metric)
162 | # Remove any assignments containing invalid speakers
163 | inactive_speakers = np.array(
164 | [spk for spk in range(num_local_speakers) if spk not in active_speakers]
165 | )
166 | dist_map = dist_map.unmap_speakers(inactive_speakers, self.inactive_centers)
167 | # Keep assignments under the distance threshold
168 | valid_map = dist_map.unmap_threshold(self.delta_new)
169 |
170 | # Some speakers might be unidentified
171 | missed_speakers = [
172 | s for s in active_speakers if not valid_map.is_source_speaker_mapped(s)
173 | ]
174 |
175 | # Add assignments to new centers if possible
176 | new_center_speakers = []
177 | for spk in missed_speakers:
178 | has_space = len(new_center_speakers) < self.num_free_centers
179 | if has_space and spk in long_speakers:
180 | # Flag as a new center
181 | new_center_speakers.append(spk)
182 | else:
183 | # Cannot create a new center
184 | # Get global speakers in order of preference
185 | preferences = np.argsort(dist_map.mapping_matrix[spk, :])
186 | preferences = [
187 | g_spk for g_spk in preferences if g_spk in self.active_centers
188 | ]
189 | # Get the free global speakers among the preferences
190 | _, g_assigned = valid_map.valid_assignments()
191 | free = [g_spk for g_spk in preferences if g_spk not in g_assigned]
192 | if free:
193 | # The best global speaker is the closest free one
194 | valid_map = valid_map.set_source_speaker(spk, free[0])
195 |
196 | # Update known centers
197 | to_update = [
198 | (ls, gs)
199 | for ls, gs in zip(*valid_map.valid_assignments())
200 | if ls not in missed_speakers and ls in long_speakers
201 | ]
202 | self.update(to_update, embeddings)
203 |
204 | # Add new centers
205 | for spk in new_center_speakers:
206 | valid_map = valid_map.set_source_speaker(
207 | spk, self.add_center(embeddings[spk])
208 | )
209 |
210 | return valid_map
211 |
212 | def __call__(
213 | self, segmentation: SlidingWindowFeature, embeddings: torch.Tensor
214 | ) -> SlidingWindowFeature:
215 | return SlidingWindowFeature(
216 | self.identify(segmentation, embeddings).apply(segmentation.data),
217 | segmentation.sliding_window,
218 | )
219 |
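
Usage sketch with synthetic inputs (shapes and thresholds are illustrative assumptions): one chunk of segmentation scores plus one embedding per local speaker.

import numpy as np
import torch
from pyannote.core import SlidingWindow, SlidingWindowFeature

frames, local_speakers, emb_dim = 500, 3, 512
resolution = 5.0 / frames
segmentation = SlidingWindowFeature(
    np.random.rand(frames, local_speakers),
    SlidingWindow(start=0.0, duration=resolution, step=resolution),
)
embeddings = torch.randn(local_speakers, emb_dim)

clustering = OnlineSpeakerClustering(tau_active=0.5, rho_update=0.3, delta_new=1.0)
# Local speaker columns are permuted/projected to their global centroid positions
global_segmentation = clustering(segmentation, embeddings)
print(global_segmentation.data.shape)  # expected (frames, max_speakers), here (500, 20)
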
--------------------------------------------------------------------------------
/src/diart/blocks/diarization.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Sequence
4 |
5 | import numpy as np
6 | import torch
7 | from pyannote.core import Annotation, SlidingWindowFeature, SlidingWindow, Segment
8 | from pyannote.metrics.base import BaseMetric
9 | from pyannote.metrics.diarization import DiarizationErrorRate
10 | from typing_extensions import Literal
11 |
12 | from . import base
13 | from .aggregation import DelayedAggregation
14 | from .clustering import OnlineSpeakerClustering
15 | from .embedding import OverlapAwareSpeakerEmbedding
16 | from .segmentation import SpeakerSegmentation
17 | from .utils import Binarize
18 | from .. import models as m
19 |
20 |
21 | class SpeakerDiarizationConfig(base.PipelineConfig):
22 | def __init__(
23 | self,
24 | segmentation: m.SegmentationModel | None = None,
25 | embedding: m.EmbeddingModel | None = None,
26 | duration: float = 5,
27 | step: float = 0.5,
28 | latency: float | Literal["max", "min"] | None = None,
29 | tau_active: float = 0.6,
30 | rho_update: float = 0.3,
31 | delta_new: float = 1,
32 | gamma: float = 3,
33 | beta: float = 10,
34 | max_speakers: int = 20,
35 | normalize_embedding_weights: bool = False,
36 | device: torch.device | None = None,
37 | sample_rate: int = 16000,
38 | **kwargs,
39 | ):
40 | # Default segmentation model is pyannote/segmentation
41 | self.segmentation = segmentation or m.SegmentationModel.from_pyannote(
42 | "pyannote/segmentation"
43 | )
44 |
45 | # Default embedding model is pyannote/embedding
46 | self.embedding = embedding or m.EmbeddingModel.from_pyannote(
47 | "pyannote/embedding"
48 | )
49 |
50 | self._duration = duration
51 | self._sample_rate = sample_rate
52 |
53 | # Latency defaults to the step duration
54 | self._step = step
55 | self._latency = latency
56 | if self._latency is None or self._latency == "min":
57 | self._latency = self._step
58 | elif self._latency == "max":
59 | self._latency = self._duration
60 |
61 | self.tau_active = tau_active
62 | self.rho_update = rho_update
63 | self.delta_new = delta_new
64 | self.gamma = gamma
65 | self.beta = beta
66 | self.max_speakers = max_speakers
67 | self.normalize_embedding_weights = normalize_embedding_weights
68 | self.device = device or torch.device(
69 | "cuda" if torch.cuda.is_available() else "cpu"
70 | )
71 |
72 | @property
73 | def duration(self) -> float:
74 | return self._duration
75 |
76 | @property
77 | def step(self) -> float:
78 | return self._step
79 |
80 | @property
81 | def latency(self) -> float:
82 | return self._latency
83 |
84 | @property
85 | def sample_rate(self) -> int:
86 | return self._sample_rate
87 |
88 |
89 | class SpeakerDiarization(base.Pipeline):
90 | def __init__(self, config: SpeakerDiarizationConfig | None = None):
91 | self._config = SpeakerDiarizationConfig() if config is None else config
92 |
93 | msg = f"Latency should be in the range [{self._config.step}, {self._config.duration}]"
94 | assert self._config.step <= self._config.latency <= self._config.duration, msg
95 |
96 | self.segmentation = SpeakerSegmentation(
97 | self._config.segmentation, self._config.device
98 | )
99 | self.embedding = OverlapAwareSpeakerEmbedding(
100 | self._config.embedding,
101 | self._config.gamma,
102 | self._config.beta,
103 | norm=1,
104 | normalize_weights=self._config.normalize_embedding_weights,
105 | device=self._config.device,
106 | )
107 | self.pred_aggregation = DelayedAggregation(
108 | self._config.step,
109 | self._config.latency,
110 | strategy="hamming",
111 | cropping_mode="loose",
112 | )
113 | self.audio_aggregation = DelayedAggregation(
114 | self._config.step,
115 | self._config.latency,
116 | strategy="first",
117 | cropping_mode="center",
118 | )
119 | self.binarize = Binarize(self._config.tau_active)
120 |
121 | # Internal state, handle with care
122 | self.timestamp_shift = 0
123 | self.clustering = None
124 | self.chunk_buffer, self.pred_buffer = [], []
125 | self.reset()
126 |
127 | @staticmethod
128 | def get_config_class() -> type:
129 | return SpeakerDiarizationConfig
130 |
131 | @staticmethod
132 | def suggest_metric() -> BaseMetric:
133 | return DiarizationErrorRate(collar=0, skip_overlap=False)
134 |
135 | @staticmethod
136 | def hyper_parameters() -> Sequence[base.HyperParameter]:
137 | return [base.TauActive, base.RhoUpdate, base.DeltaNew]
138 |
139 | @property
140 | def config(self) -> SpeakerDiarizationConfig:
141 | return self._config
142 |
143 | def set_timestamp_shift(self, shift: float):
144 | self.timestamp_shift = shift
145 |
146 | def reset(self):
147 | self.set_timestamp_shift(0)
148 | self.clustering = OnlineSpeakerClustering(
149 | self.config.tau_active,
150 | self.config.rho_update,
151 | self.config.delta_new,
152 | "cosine",
153 | self.config.max_speakers,
154 | )
155 | self.chunk_buffer, self.pred_buffer = [], []
156 |
157 | def __call__(
158 | self, waveforms: Sequence[SlidingWindowFeature]
159 | ) -> Sequence[tuple[Annotation, SlidingWindowFeature]]:
160 | """Diarize the next audio chunks of an audio stream.
161 |
162 | Parameters
163 | ----------
164 | waveforms: Sequence[SlidingWindowFeature]
165 | A sequence of consecutive audio chunks from an audio stream.
166 |
167 | Returns
168 | -------
169 | Sequence[tuple[Annotation, SlidingWindowFeature]]
170 | Speaker diarization of each chunk alongside their corresponding audio.
171 | """
172 | batch_size = len(waveforms)
173 | msg = "Pipeline expected at least 1 input"
174 | assert batch_size >= 1, msg
175 |
176 | # Create batch from chunk sequence, shape (batch, samples, channels)
177 | batch = torch.stack([torch.from_numpy(w.data) for w in waveforms])
178 |
179 | expected_num_samples = int(
180 | np.rint(self.config.duration * self.config.sample_rate)
181 | )
182 | msg = f"Expected {expected_num_samples} samples per chunk, but got {batch.shape[1]}"
183 | assert batch.shape[1] == expected_num_samples, msg
184 |
185 | # Extract segmentation and embeddings
186 | segmentations = self.segmentation(batch) # shape (batch, frames, speakers)
187 | # embeddings has shape (batch, speakers, emb_dim)
188 | embeddings = self.embedding(batch, segmentations)
189 |
190 | seg_resolution = waveforms[0].extent.duration / segmentations.shape[1]
191 |
192 | outputs = []
193 | for wav, seg, emb in zip(waveforms, segmentations, embeddings):
194 | # Add timestamps to segmentation
195 | sw = SlidingWindow(
196 | start=wav.extent.start,
197 | duration=seg_resolution,
198 | step=seg_resolution,
199 | )
200 | seg = SlidingWindowFeature(seg.cpu().numpy(), sw)
201 |
202 | # Update clustering state and permute segmentation
203 | permuted_seg = self.clustering(seg, emb)
204 |
205 | # Update sliding buffer
206 | self.chunk_buffer.append(wav)
207 | self.pred_buffer.append(permuted_seg)
208 |
209 | # Aggregate buffer outputs for this time step
210 | agg_waveform = self.audio_aggregation(self.chunk_buffer)
211 | agg_prediction = self.pred_aggregation(self.pred_buffer)
212 | agg_prediction = self.binarize(agg_prediction)
213 |
214 | # Shift prediction timestamps if required
215 | if self.timestamp_shift != 0:
216 | shifted_agg_prediction = Annotation(agg_prediction.uri)
217 | for segment, track, speaker in agg_prediction.itertracks(
218 | yield_label=True
219 | ):
220 | new_segment = Segment(
221 | segment.start + self.timestamp_shift,
222 | segment.end + self.timestamp_shift,
223 | )
224 | shifted_agg_prediction[new_segment, track] = speaker
225 | agg_prediction = shifted_agg_prediction
226 |
227 | outputs.append((agg_prediction, agg_waveform))
228 |
229 | # Make place for new chunks in buffer if required
230 | if len(self.chunk_buffer) == self.pred_aggregation.num_overlapping_windows:
231 | self.chunk_buffer = self.chunk_buffer[1:]
232 | self.pred_buffer = self.pred_buffer[1:]
233 |
234 | return outputs
235 |
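
A minimal end-to-end sketch: two consecutive synthetic 5-second chunks (0.5s apart) pushed through the pipeline. It assumes the default pyannote models can be downloaded (a Hugging Face token may be required); with random noise the resulting annotations will typically be empty, so real speech chunks are needed for meaningful turns.

import numpy as np
import torch
from pyannote.core import SlidingWindow, SlidingWindowFeature

config = SpeakerDiarizationConfig(step=0.5, latency=0.5, device=torch.device("cpu"))
pipeline = SpeakerDiarization(config)

samples = int(config.duration * config.sample_rate)  # 5s at 16kHz -> 80000
sample_resolution = 1.0 / config.sample_rate
chunks = [
    SlidingWindowFeature(
        np.random.randn(samples, 1).astype(np.float32),
        SlidingWindow(start=i * config.step, duration=sample_resolution, step=sample_resolution),
    )
    for i in range(2)
]
for annotation, waveform in pipeline(chunks):
    print(annotation)  # pyannote.core.Annotation with the speakers detected so far
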
--------------------------------------------------------------------------------
/src/diart/blocks/embedding.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union, Text
2 |
3 | import torch
4 | from einops import rearrange
5 |
6 | from .. import functional as F
7 | from ..features import TemporalFeatures, TemporalFeatureFormatter
8 | from ..models import EmbeddingModel
9 |
10 |
11 | class SpeakerEmbedding:
12 | def __init__(self, model: EmbeddingModel, device: Optional[torch.device] = None):
13 | self.model = model
14 | self.model.eval()
15 | self.device = device
16 | if self.device is None:
17 | self.device = torch.device("cpu")
18 | self.model.to(self.device)
19 | self.waveform_formatter = TemporalFeatureFormatter()
20 | self.weights_formatter = TemporalFeatureFormatter()
21 |
22 | @staticmethod
23 | def from_pretrained(
24 | model,
25 | use_hf_token: Union[Text, bool, None] = True,
26 | device: Optional[torch.device] = None,
27 | ) -> "SpeakerEmbedding":
28 | emb_model = EmbeddingModel.from_pretrained(model, use_hf_token)
29 | return SpeakerEmbedding(emb_model, device)
30 |
31 | def __call__(
32 | self, waveform: TemporalFeatures, weights: Optional[TemporalFeatures] = None
33 | ) -> torch.Tensor:
34 | """
35 | Calculate speaker embeddings of input audio.
36 | If weights are given, calculate many speaker embeddings from the same waveform.
37 |
38 | Parameters
39 | ----------
40 | waveform: TemporalFeatures, shape (samples, channels) or (batch, samples, channels)
41 | weights: Optional[TemporalFeatures], shape (frames, speakers) or (batch, frames, speakers)
42 | Per-speaker and per-frame weights. Defaults to no weights.
43 |
44 | Returns
45 | -------
46 | embeddings: torch.Tensor
47 | If weights are provided, the shape is (batch, speakers, embedding_dim),
48 | otherwise the shape is (batch, embedding_dim).
49 | If batch size == 1, the batch dimension is omitted.
50 | """
51 | with torch.no_grad():
52 | inputs = self.waveform_formatter.cast(waveform).to(self.device)
53 | inputs = rearrange(inputs, "batch sample channel -> batch channel sample")
54 | if weights is not None:
55 | weights = self.weights_formatter.cast(weights).to(self.device)
56 | batch_size, _, num_speakers = weights.shape
57 | inputs = inputs.repeat(1, num_speakers, 1)
58 | weights = rearrange(weights, "batch frame spk -> (batch spk) frame")
59 | inputs = rearrange(inputs, "batch spk sample -> (batch spk) 1 sample")
60 | output = rearrange(
61 | self.model(inputs, weights),
62 | "(batch spk) feat -> batch spk feat",
63 | batch=batch_size,
64 | spk=num_speakers,
65 | )
66 | else:
67 | output = self.model(inputs)
68 | return output.squeeze().cpu()
69 |
70 |
71 | class OverlappedSpeechPenalty:
72 | """Applies a penalty on overlapping speech and low-confidence regions to speaker segmentation scores.
73 |
74 | .. note::
75 | For more information, see `"Overlap-Aware Low-Latency Online Speaker Diarization
76 |         based on End-to-End Local Segmentation" <https://arxiv.org/abs/2109.06483>`_
77 | (Section 2.2.1 Segmentation-driven speaker embedding). This block implements Equation 2.
78 |
79 | Parameters
80 | ----------
81 | gamma: float, optional
82 | Exponent to lower low-confidence predictions.
83 | Defaults to 3.
84 | beta: float, optional
85 | Temperature parameter (actually 1/beta) to lower joint speaker activations.
86 | Defaults to 10.
87 | normalize: bool, optional
88 | Whether to min-max normalize weights to be in the range [0, 1].
89 | Defaults to False.
90 | """
91 |
92 | def __init__(self, gamma: float = 3, beta: float = 10, normalize: bool = False):
93 | self.gamma = gamma
94 | self.beta = beta
95 | self.formatter = TemporalFeatureFormatter()
96 | self.normalize = normalize
97 |
98 | def __call__(self, segmentation: TemporalFeatures) -> TemporalFeatures:
99 | weights = self.formatter.cast(segmentation) # shape (batch, frames, speakers)
100 | with torch.inference_mode():
101 | weights = F.overlapped_speech_penalty(weights, self.gamma, self.beta)
102 | if self.normalize:
103 | min_values = weights.min(dim=1, keepdim=True).values
104 | max_values = weights.max(dim=1, keepdim=True).values
105 | weights = (weights - min_values) / (max_values - min_values)
106 | weights.nan_to_num_(1e-8)
107 | return self.formatter.restore_type(weights)
108 |
109 |
110 | class EmbeddingNormalization:
111 | def __init__(self, norm: Union[float, torch.Tensor] = 1):
112 | self.norm = norm
113 | # Add batch dimension if missing
114 | if isinstance(self.norm, torch.Tensor) and self.norm.ndim == 2:
115 | self.norm = self.norm.unsqueeze(0)
116 |
117 | def __call__(self, embeddings: torch.Tensor) -> torch.Tensor:
118 | with torch.inference_mode():
119 | norm_embs = F.normalize_embeddings(embeddings, self.norm)
120 | return norm_embs
121 |
122 |
123 | class OverlapAwareSpeakerEmbedding:
124 | """
125 | Extract overlap-aware speaker embeddings given an audio chunk and its segmentation.
126 |
127 | Parameters
128 | ----------
129 | model: EmbeddingModel
130 | A pre-trained embedding model.
131 | gamma: float, optional
132 | Exponent to lower low-confidence predictions.
133 | Defaults to 3.
134 | beta: float, optional
135 | Softmax's temperature parameter (actually 1/beta) to lower joint speaker activations.
136 | Defaults to 10.
137 | norm: float or torch.Tensor of shape (batch, speakers, 1) where batch is optional
138 | The target norm for the embeddings. It can be different for each speaker.
139 | Defaults to 1.
140 | normalize_weights: bool, optional
141 | Whether to min-max normalize embedding weights to be in the range [0, 1].
142 | device: Optional[torch.device]
143 | The device on which to run the embedding model.
144 | Defaults to GPU if available or CPU if not.
145 | """
146 |
147 | def __init__(
148 | self,
149 | model: EmbeddingModel,
150 | gamma: float = 3,
151 | beta: float = 10,
152 | norm: Union[float, torch.Tensor] = 1,
153 | normalize_weights: bool = False,
154 | device: Optional[torch.device] = None,
155 | ):
156 | self.embedding = SpeakerEmbedding(model, device)
157 | self.osp = OverlappedSpeechPenalty(gamma, beta, normalize_weights)
158 | self.normalize = EmbeddingNormalization(norm)
159 |
160 | @staticmethod
161 | def from_pretrained(
162 | model,
163 | gamma: float = 3,
164 | beta: float = 10,
165 | norm: Union[float, torch.Tensor] = 1,
166 | use_hf_token: Union[Text, bool, None] = True,
167 | normalize_weights: bool = False,
168 | device: Optional[torch.device] = None,
169 | ):
170 | model = EmbeddingModel.from_pretrained(model, use_hf_token)
171 | return OverlapAwareSpeakerEmbedding(
172 | model, gamma, beta, norm, normalize_weights, device
173 | )
174 |
175 | def __call__(
176 | self, waveform: TemporalFeatures, segmentation: TemporalFeatures
177 | ) -> torch.Tensor:
178 | return self.normalize(self.embedding(waveform, self.osp(segmentation)))
179 |
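
Usage sketch: one embedding per (chunk, local speaker) from segmentation scores. The model name, shapes and embedding dimension are assumptions; downloading "pyannote/embedding" may require a Hugging Face token.

import torch

embedding_block = OverlapAwareSpeakerEmbedding.from_pretrained("pyannote/embedding")

waveform = torch.randn(2, 80000, 1)   # (batch, samples, channels): two 5s chunks at 16kHz
segmentation = torch.rand(2, 293, 3)  # (batch, frames, speakers), as produced by the segmentation block

embeddings = embedding_block(waveform, segmentation)
print(embeddings.shape)  # (batch, speakers, embedding_dim), e.g. torch.Size([2, 3, 512])
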
--------------------------------------------------------------------------------
/src/diart/blocks/segmentation.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union, Text
2 |
3 | import torch
4 | from einops import rearrange
5 |
6 | from ..features import TemporalFeatures, TemporalFeatureFormatter
7 | from ..models import SegmentationModel
8 |
9 |
10 | class SpeakerSegmentation:
11 | def __init__(self, model: SegmentationModel, device: Optional[torch.device] = None):
12 | self.model = model
13 | self.model.eval()
14 | self.device = device
15 | if self.device is None:
16 | self.device = torch.device("cpu")
17 | self.model.to(self.device)
18 | self.formatter = TemporalFeatureFormatter()
19 |
20 | @staticmethod
21 | def from_pretrained(
22 | model,
23 | use_hf_token: Union[Text, bool, None] = True,
24 | device: Optional[torch.device] = None,
25 | ) -> "SpeakerSegmentation":
26 | seg_model = SegmentationModel.from_pretrained(model, use_hf_token)
27 | return SpeakerSegmentation(seg_model, device)
28 |
29 | def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures:
30 | """
31 | Calculate the speaker segmentation of input audio.
32 |
33 | Parameters
34 | ----------
35 | waveform: TemporalFeatures, shape (samples, channels) or (batch, samples, channels)
36 |
37 | Returns
38 | -------
39 | speaker_segmentation: TemporalFeatures, shape (batch, frames, speakers)
40 | The batch dimension is omitted if waveform is a `SlidingWindowFeature`.
41 | """
42 | with torch.no_grad():
43 | wave = rearrange(
44 | self.formatter.cast(waveform),
45 | "batch sample channel -> batch channel sample",
46 | )
47 | output = self.model(wave.to(self.device)).cpu()
48 | return self.formatter.restore_type(output)
49 |
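
Usage sketch (the checkpoint name and output shape are assumptions; downloading "pyannote/segmentation" may require a Hugging Face token):

import torch

segmentation_block = SpeakerSegmentation.from_pretrained("pyannote/segmentation")
waveform = torch.randn(1, 80000, 1)  # (batch, samples, channels): one 5s chunk at 16kHz
scores = segmentation_block(waveform)
print(scores.shape)  # (batch, frames, speakers), e.g. torch.Size([1, 293, 3]) for this model
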
--------------------------------------------------------------------------------
/src/diart/blocks/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Text, Optional
2 |
3 | import numpy as np
4 | import torch
5 | from pyannote.core import Annotation, Segment, SlidingWindowFeature
6 | import torchaudio.transforms as T
7 |
8 | from ..features import TemporalFeatures, TemporalFeatureFormatter
9 |
10 |
11 | class Binarize:
12 | """
13 | Transform a speaker segmentation from the discrete-time domain
14 | into a continuous-time speaker segmentation.
15 |
16 | Parameters
17 | ----------
18 | threshold: float
19 | Probability threshold to determine if a speaker is active at a given frame.
20 | uri: Optional[Text]
21 | Uri of the audio stream. Defaults to no uri.
22 | """
23 |
24 | def __init__(self, threshold: float, uri: Optional[Text] = None):
25 | self.uri = uri
26 | self.threshold = threshold
27 |
28 | def __call__(self, segmentation: SlidingWindowFeature) -> Annotation:
29 | """
30 | Return the continuous-time segmentation
31 | corresponding to the discrete-time input segmentation.
32 |
33 | Parameters
34 | ----------
35 | segmentation: SlidingWindowFeature
36 | Discrete-time speaker segmentation.
37 |
38 | Returns
39 | -------
40 | annotation: Annotation
41 | Continuous-time speaker segmentation.
42 | """
43 | num_frames, num_speakers = segmentation.data.shape
44 | timestamps = segmentation.sliding_window
45 | is_active = segmentation.data > self.threshold
46 | # Artificially add last inactive frame to close any remaining speaker turns
47 | is_active = np.append(is_active, [[False] * num_speakers], axis=0)
48 | start_times = np.zeros(num_speakers) + timestamps[0].middle
49 | annotation = Annotation(uri=self.uri, modality="speech")
50 | for t in range(num_frames):
51 | # Any (False, True) starts a speaker turn at "True" index
52 | onsets = np.logical_and(np.logical_not(is_active[t]), is_active[t + 1])
53 | start_times[onsets] = timestamps[t + 1].middle
54 | # Any (True, False) ends a speaker turn at "False" index
55 | offsets = np.logical_and(is_active[t], np.logical_not(is_active[t + 1]))
56 | for spk in np.where(offsets)[0]:
57 | region = Segment(start_times[spk], timestamps[t + 1].middle)
58 | annotation[region, spk] = f"speaker{spk}"
59 | return annotation
60 |
61 |
62 | class Resample:
63 | """Dynamically resample audio chunks.
64 |
65 | Parameters
66 | ----------
67 | sample_rate: int
68 | Original sample rate of the input audio
69 | resample_rate: int
70 | Sample rate of the output
71 | """
72 |
73 | def __init__(
74 | self,
75 | sample_rate: int,
76 | resample_rate: int,
77 | device: Optional[torch.device] = None,
78 | ):
79 | self.device = device
80 | if self.device is None:
81 | self.device = torch.device("cpu")
82 | self.resample = T.Resample(sample_rate, resample_rate).to(self.device)
83 | self.formatter = TemporalFeatureFormatter()
84 |
85 | def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures:
86 | wav = self.formatter.cast(waveform).to(self.device) # shape (batch, samples, 1)
87 | with torch.no_grad():
88 | resampled_wav = self.resample(wav.transpose(-1, -2)).transpose(-1, -2)
89 | return self.formatter.restore_type(resampled_wav)
90 |
91 |
92 | class AdjustVolume:
93 | """Change the volume of an audio chunk.
94 |
95 |     Note that the actual output volume may differ from the target to avoid saturation.
96 |
97 | Parameters
98 | ----------
99 | volume_in_db: float
100 | Target volume in dB.
101 | """
102 |
103 | def __init__(self, volume_in_db: float):
104 | self.target_db = volume_in_db
105 | self.formatter = TemporalFeatureFormatter()
106 |
107 | @staticmethod
108 | def get_volumes(waveforms: torch.Tensor) -> torch.Tensor:
109 | """Compute the volumes of a set of audio chunks.
110 |
111 | Parameters
112 | ----------
113 | waveforms: torch.Tensor
114 | Audio chunks. Shape (batch, samples, channels).
115 |
116 | Returns
117 | -------
118 | volumes: torch.Tensor
119 | Audio chunk volumes per channel. Shape (batch, 1, channels)
120 | """
121 | return 10 * torch.log10(
122 | torch.mean(torch.abs(waveforms) ** 2, dim=1, keepdim=True)
123 | )
124 |
125 | def __call__(self, waveform: TemporalFeatures) -> TemporalFeatures:
126 | wav = self.formatter.cast(waveform) # shape (batch, samples, channels)
127 | with torch.no_grad():
128 | # Compute current volume per chunk, shape (batch, 1, channels)
129 | current_volumes = self.get_volumes(wav)
130 | # Determine gain to reach the target volume
131 | gains = 10 ** ((self.target_db - current_volumes) / 20)
132 | # Apply gain
133 | wav = gains * wav
134 | # If maximum value is greater than one, normalize chunk
135 | maximums = torch.clamp(torch.amax(torch.abs(wav), dim=1, keepdim=True), 1)
136 | wav = wav / maximums
137 | return self.formatter.restore_type(wav)
138 |
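
A small sketch of Binarize on hand-crafted scores (threshold, shapes and timings are illustrative assumptions): two overlapping activations become two labeled speaker turns.

import numpy as np
from pyannote.core import SlidingWindow, SlidingWindowFeature

resolution = 0.1  # 100 frames covering 10 seconds
scores = np.zeros((100, 2))
scores[20:60, 0] = 0.9  # speaker0 active roughly between 2s and 6s
scores[50:80, 1] = 0.8  # speaker1 active roughly between 5s and 8s
segmentation = SlidingWindowFeature(
    scores, SlidingWindow(start=0.0, duration=resolution, step=resolution)
)

annotation = Binarize(threshold=0.5)(segmentation)
print(annotation)  # two turns: speaker0 ~[2.0s, 6.0s], speaker1 ~[5.0s, 8.0s]
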
--------------------------------------------------------------------------------
/src/diart/blocks/vad.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Sequence
4 |
5 | import numpy as np
6 | import torch
7 | from pyannote.core import (
8 | Annotation,
9 | Timeline,
10 | SlidingWindowFeature,
11 | SlidingWindow,
12 | Segment,
13 | )
14 | from pyannote.metrics.base import BaseMetric
15 | from pyannote.metrics.detection import DetectionErrorRate
16 | from typing_extensions import Literal
17 |
18 | from . import base
19 | from .aggregation import DelayedAggregation
20 | from .segmentation import SpeakerSegmentation
21 | from .utils import Binarize
22 | from .. import models as m
23 | from .. import utils
24 |
25 |
26 | class VoiceActivityDetectionConfig(base.PipelineConfig):
27 | def __init__(
28 | self,
29 | segmentation: m.SegmentationModel | None = None,
30 | duration: float = 5,
31 | step: float = 0.5,
32 | latency: float | Literal["max", "min"] | None = None,
33 | tau_active: float = 0.6,
34 | device: torch.device | None = None,
35 | sample_rate: int = 16000,
36 | **kwargs,
37 | ):
38 | # Default segmentation model is pyannote/segmentation
39 | self.segmentation = segmentation or m.SegmentationModel.from_pyannote(
40 | "pyannote/segmentation"
41 | )
42 |
43 | self._duration = duration
44 | self._step = step
45 | self._sample_rate = sample_rate
46 |
47 | # Latency defaults to the step duration
48 | self._latency = latency
49 | if self._latency is None or self._latency == "min":
50 | self._latency = self._step
51 | elif self._latency == "max":
52 | self._latency = self._duration
53 |
54 | self.tau_active = tau_active
55 | self.device = device or torch.device(
56 | "cuda" if torch.cuda.is_available() else "cpu"
57 | )
58 |
59 | @property
60 | def duration(self) -> float:
61 | return self._duration
62 |
63 | @property
64 | def step(self) -> float:
65 | return self._step
66 |
67 | @property
68 | def latency(self) -> float:
69 | return self._latency
70 |
71 | @property
72 | def sample_rate(self) -> int:
73 | return self._sample_rate
74 |
75 |
76 | class VoiceActivityDetection(base.Pipeline):
77 | def __init__(self, config: VoiceActivityDetectionConfig | None = None):
78 | self._config = VoiceActivityDetectionConfig() if config is None else config
79 |
80 | msg = f"Latency should be in the range [{self._config.step}, {self._config.duration}]"
81 | assert self._config.step <= self._config.latency <= self._config.duration, msg
82 |
83 | self.segmentation = SpeakerSegmentation(
84 | self._config.segmentation, self._config.device
85 | )
86 | self.pred_aggregation = DelayedAggregation(
87 | self._config.step,
88 | self._config.latency,
89 | strategy="hamming",
90 | cropping_mode="loose",
91 | )
92 | self.audio_aggregation = DelayedAggregation(
93 | self._config.step,
94 | self._config.latency,
95 | strategy="first",
96 | cropping_mode="center",
97 | )
98 | self.binarize = Binarize(self._config.tau_active)
99 |
100 | # Internal state, handle with care
101 | self.timestamp_shift = 0
102 | self.chunk_buffer, self.pred_buffer = [], []
103 |
104 | @staticmethod
105 | def get_config_class() -> type:
106 | return VoiceActivityDetectionConfig
107 |
108 | @staticmethod
109 | def suggest_metric() -> BaseMetric:
110 | return DetectionErrorRate(collar=0, skip_overlap=False)
111 |
112 | @staticmethod
113 | def hyper_parameters() -> Sequence[base.HyperParameter]:
114 | return [base.TauActive]
115 |
116 | @property
117 | def config(self) -> base.PipelineConfig:
118 | return self._config
119 |
120 | def reset(self):
121 | self.set_timestamp_shift(0)
122 | self.chunk_buffer, self.pred_buffer = [], []
123 |
124 | def set_timestamp_shift(self, shift: float):
125 | self.timestamp_shift = shift
126 |
127 | def __call__(
128 | self,
129 | waveforms: Sequence[SlidingWindowFeature],
130 | ) -> Sequence[tuple[Annotation, SlidingWindowFeature]]:
131 | batch_size = len(waveforms)
132 | msg = "Pipeline expected at least 1 input"
133 | assert batch_size >= 1, msg
134 |
135 | # Create batch from chunk sequence, shape (batch, samples, channels)
136 | batch = torch.stack([torch.from_numpy(w.data) for w in waveforms])
137 |
138 | expected_num_samples = int(
139 | np.rint(self.config.duration * self.config.sample_rate)
140 | )
141 | msg = f"Expected {expected_num_samples} samples per chunk, but got {batch.shape[1]}"
142 | assert batch.shape[1] == expected_num_samples, msg
143 |
144 | # Extract segmentation
145 | segmentations = self.segmentation(batch) # shape (batch, frames, speakers)
146 | voice_detection = torch.max(segmentations, dim=-1, keepdim=True)[
147 | 0
148 | ] # shape (batch, frames, 1)
149 |
150 | seg_resolution = waveforms[0].extent.duration / segmentations.shape[1]
151 |
152 | outputs = []
153 | for wav, vad in zip(waveforms, voice_detection):
154 | # Add timestamps to segmentation
155 | sw = SlidingWindow(
156 | start=wav.extent.start,
157 | duration=seg_resolution,
158 | step=seg_resolution,
159 | )
160 | vad = SlidingWindowFeature(vad.cpu().numpy(), sw)
161 |
162 | # Update sliding buffer
163 | self.chunk_buffer.append(wav)
164 | self.pred_buffer.append(vad)
165 |
166 | # Aggregate buffer outputs for this time step
167 | agg_waveform = self.audio_aggregation(self.chunk_buffer)
168 | agg_prediction = self.pred_aggregation(self.pred_buffer)
169 | agg_prediction = self.binarize(agg_prediction).get_timeline(copy=False)
170 |
171 | # Shift prediction timestamps if required
172 | if self.timestamp_shift != 0:
173 | shifted_agg_prediction = Timeline(uri=agg_prediction.uri)
174 | for segment in agg_prediction:
175 | new_segment = Segment(
176 | segment.start + self.timestamp_shift,
177 | segment.end + self.timestamp_shift,
178 | )
179 | shifted_agg_prediction.add(new_segment)
180 | agg_prediction = shifted_agg_prediction
181 |
182 | # Convert timeline into annotation with single speaker "speech"
183 | agg_prediction = agg_prediction.to_annotation(utils.repeat_label("speech"))
184 | outputs.append((agg_prediction, agg_waveform))
185 |
186 | # Make place for new chunks in buffer if required
187 | if len(self.chunk_buffer) == self.pred_aggregation.num_overlapping_windows:
188 | self.chunk_buffer = self.chunk_buffer[1:]
189 | self.pred_buffer = self.pred_buffer[1:]
190 |
191 | return outputs
192 |
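
A minimal sketch mirroring the diarization pipeline: one synthetic 5-second chunk pushed through voice activity detection. It assumes the default pyannote segmentation model can be downloaded; with random noise the output is typically empty.

import numpy as np
import torch
from pyannote.core import SlidingWindow, SlidingWindowFeature

config = VoiceActivityDetectionConfig(step=0.5, latency=0.5, device=torch.device("cpu"))
vad_pipeline = VoiceActivityDetection(config)

samples = int(config.duration * config.sample_rate)  # 5s at 16kHz -> 80000
sample_resolution = 1.0 / config.sample_rate
chunk = SlidingWindowFeature(
    np.random.randn(samples, 1).astype(np.float32),
    SlidingWindow(start=0.0, duration=sample_resolution, step=sample_resolution),
)
annotation, waveform = vad_pipeline([chunk])[0]
print(annotation)  # Annotation where every detected region is labeled "speech"
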
--------------------------------------------------------------------------------
/src/diart/console/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juanmc2005/diart/392d53a1b0cd67701ecc20b683bb10614df2f7fc/src/diart/console/__init__.py
--------------------------------------------------------------------------------
/src/diart/console/benchmark.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 |
4 | import pandas as pd
5 | import torch
6 |
7 | from diart import argdoc
8 | from diart import models as m
9 | from diart import utils
10 | from diart.inference import Benchmark, Parallelize
11 |
12 |
13 | def run():
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument(
16 | "root",
17 | type=Path,
18 | help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)",
19 | )
20 | parser.add_argument(
21 | "--pipeline",
22 | default="SpeakerDiarization",
23 | type=str,
24 | help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'",
25 | )
26 | parser.add_argument(
27 | "--segmentation",
28 | default="pyannote/segmentation",
29 | type=str,
30 | help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation",
31 | )
32 | parser.add_argument(
33 | "--embedding",
34 | default="pyannote/embedding",
35 | type=str,
36 | help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding",
37 | )
38 | parser.add_argument(
39 | "--reference",
40 | type=Path,
41 | help="Optional. Directory with RTTM files CONVERSATION.rttm. Names must match audio files",
42 | )
43 | parser.add_argument(
44 | "--duration",
45 | type=float,
46 | default=5,
47 | help=f"{argdoc.DURATION}. Defaults to training segmentation duration",
48 | )
49 | parser.add_argument(
50 | "--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5"
51 | )
52 | parser.add_argument(
53 | "--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5"
54 | )
55 | parser.add_argument(
56 | "--tau-active", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5"
57 | )
58 | parser.add_argument(
59 | "--rho-update", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3"
60 | )
61 | parser.add_argument(
62 | "--delta-new", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1"
63 | )
64 | parser.add_argument(
65 | "--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3"
66 | )
67 | parser.add_argument(
68 | "--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10"
69 | )
70 | parser.add_argument(
71 | "--max-speakers",
72 | default=20,
73 | type=int,
74 | help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20",
75 | )
76 | parser.add_argument(
77 | "--batch-size",
78 | default=32,
79 | type=int,
80 | help=f"{argdoc.BATCH_SIZE}. Defaults to 32",
81 | )
82 | parser.add_argument(
83 | "--num-workers",
84 | default=0,
85 | type=int,
86 | help=f"{argdoc.NUM_WORKERS}. Defaults to 0 (no parallelism)",
87 | )
88 | parser.add_argument(
89 | "--cpu",
90 | dest="cpu",
91 | action="store_true",
92 | help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise",
93 | )
94 | parser.add_argument(
95 | "--output", type=Path, help=f"{argdoc.OUTPUT}. Defaults to no writing"
96 | )
97 | parser.add_argument(
98 | "--hf-token",
99 | default="true",
100 | type=str,
101 | help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)",
102 | )
103 | parser.add_argument(
104 | "--normalize-embedding-weights",
105 | action="store_true",
106 | help=f"{argdoc.NORMALIZE_EMBEDDING_WEIGHTS}. Defaults to False",
107 | )
108 | args = parser.parse_args()
109 |
110 | # Resolve device
111 | args.device = torch.device("cpu") if args.cpu else None
112 |
113 | # Resolve models
114 | hf_token = utils.parse_hf_token_arg(args.hf_token)
115 | args.segmentation = m.SegmentationModel.from_pretrained(args.segmentation, hf_token)
116 | args.embedding = m.EmbeddingModel.from_pretrained(args.embedding, hf_token)
117 |
118 | pipeline_class = utils.get_pipeline_class(args.pipeline)
119 |
120 | benchmark = Benchmark(
121 | args.root,
122 | args.reference,
123 | args.output,
124 | show_progress=True,
125 | show_report=True,
126 | batch_size=args.batch_size,
127 | )
128 |
129 | config = pipeline_class.get_config_class()(**vars(args))
130 | if args.num_workers > 0:
131 | benchmark = Parallelize(benchmark, args.num_workers)
132 |
133 | report = benchmark(pipeline_class, config)
134 |
135 | if args.output is not None and isinstance(report, pd.DataFrame):
136 | report.to_csv(args.output / "benchmark_report.csv")
137 |
138 |
139 | if __name__ == "__main__":
140 | run()
141 |
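
The same benchmark can be run programmatically; a hedged sketch (the paths are placeholders, and the default pyannote models may require a Hugging Face token):

from pathlib import Path

from diart import utils
from diart.inference import Benchmark

root = Path("/path/to/audio")      # directory with CONVERSATION.(wav|flac|...) files
reference = Path("/path/to/rttm")  # optional directory with matching CONVERSATION.rttm files
output = None                      # no report/RTTM writing

benchmark = Benchmark(
    root, reference, output, show_progress=True, show_report=True, batch_size=32
)
pipeline_class = utils.get_pipeline_class("SpeakerDiarization")
config = pipeline_class.get_config_class()(step=0.5, latency=0.5)
report = benchmark(pipeline_class, config)  # pandas DataFrame with per-file metrics when references are provided
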
--------------------------------------------------------------------------------
/src/diart/console/client.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 | from threading import Thread
4 | from typing import Text, Optional
5 |
6 | import rx.operators as ops
7 | from websocket import WebSocket
8 |
9 | from diart import argdoc
10 | from diart import sources as src
11 | from diart import utils
12 |
13 |
14 | def send_audio(ws: WebSocket, source: Text, step: float, sample_rate: int):
15 | # Create audio source
16 | source_components = source.split(":")
17 | if source_components[0] != "microphone":
18 | audio_source = src.FileAudioSource(source, sample_rate, block_duration=step)
19 | else:
20 | device = int(source_components[1]) if len(source_components) > 1 else None
21 | audio_source = src.MicrophoneAudioSource(step, device)
22 |
23 | # Encode audio, then send through websocket
24 | audio_source.stream.pipe(ops.map(utils.encode_audio)).subscribe_(ws.send)
25 |
26 | # Start reading audio
27 | audio_source.read()
28 |
29 |
30 | def receive_audio(ws: WebSocket, output: Optional[Path]):
31 | while True:
32 | message = ws.recv()
33 | print(f"Received: {message}", end="")
34 | if output is not None:
35 | with open(output, "a") as file:
36 | file.write(message)
37 |
38 |
39 | def run():
40 | parser = argparse.ArgumentParser()
41 | parser.add_argument(
42 | "source",
43 | type=str,
44 | help="Path to an audio file | 'microphone' | 'microphone:'",
45 | )
46 | parser.add_argument("--host", required=True, type=str, help="Server host")
47 | parser.add_argument("--port", required=True, type=int, help="Server port")
48 | parser.add_argument(
49 | "--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5"
50 | )
51 | parser.add_argument(
52 | "-sr",
53 | "--sample-rate",
54 | default=16000,
55 | type=int,
56 | help=f"{argdoc.SAMPLE_RATE}. Defaults to 16000",
57 | )
58 | parser.add_argument(
59 | "-o",
60 | "--output-file",
61 | type=Path,
62 | help="Output RTTM file. Defaults to no writing",
63 | )
64 | args = parser.parse_args()
65 |
66 | # Run websocket client
67 | ws = WebSocket()
68 | ws.connect(f"ws://{args.host}:{args.port}")
69 | sender = Thread(
70 | target=send_audio, args=[ws, args.source, args.step, args.sample_rate]
71 | )
72 | receiver = Thread(target=receive_audio, args=[ws, args.output_file])
73 | sender.start()
74 | receiver.start()
75 |
76 |
77 | if __name__ == "__main__":
78 | run()
79 |
--------------------------------------------------------------------------------
/src/diart/console/serve.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 |
4 | import torch
5 |
6 | from diart import argdoc
7 | from diart import models as m
8 | from diart import sources as src
9 | from diart import utils
10 | from diart.inference import StreamingInference
11 | from diart.sinks import RTTMWriter
12 |
13 |
14 | def run():
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument("--host", default="0.0.0.0", type=str, help="Server host")
17 | parser.add_argument("--port", default=7007, type=int, help="Server port")
18 | parser.add_argument(
19 | "--pipeline",
20 | default="SpeakerDiarization",
21 | type=str,
22 | help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'",
23 | )
24 | parser.add_argument(
25 | "--segmentation",
26 | default="pyannote/segmentation",
27 | type=str,
28 | help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation",
29 | )
30 | parser.add_argument(
31 | "--embedding",
32 | default="pyannote/embedding",
33 | type=str,
34 | help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding",
35 | )
36 | parser.add_argument(
37 | "--duration",
38 | type=float,
39 | default=5,
40 | help=f"{argdoc.DURATION}. Defaults to training segmentation duration",
41 | )
42 | parser.add_argument(
43 | "--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5"
44 | )
45 | parser.add_argument(
46 | "--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5"
47 | )
48 | parser.add_argument(
49 | "--tau-active", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5"
50 | )
51 | parser.add_argument(
52 | "--rho-update", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3"
53 | )
54 | parser.add_argument(
55 | "--delta-new", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1"
56 | )
57 | parser.add_argument(
58 | "--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3"
59 | )
60 | parser.add_argument(
61 | "--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10"
62 | )
63 | parser.add_argument(
64 | "--max-speakers",
65 | default=20,
66 | type=int,
67 | help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20",
68 | )
69 | parser.add_argument(
70 | "--cpu",
71 | dest="cpu",
72 | action="store_true",
73 | help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise",
74 | )
75 | parser.add_argument(
76 | "--output", type=Path, help=f"{argdoc.OUTPUT}. Defaults to no writing"
77 | )
78 | parser.add_argument(
79 | "--hf-token",
80 | default="true",
81 | type=str,
82 | help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)",
83 | )
84 | parser.add_argument(
85 | "--normalize-embedding-weights",
86 | action="store_true",
87 | help=f"{argdoc.NORMALIZE_EMBEDDING_WEIGHTS}. Defaults to False",
88 | )
89 | args = parser.parse_args()
90 |
91 | # Resolve device
92 | args.device = torch.device("cpu") if args.cpu else None
93 |
94 | # Resolve models
95 | hf_token = utils.parse_hf_token_arg(args.hf_token)
96 | args.segmentation = m.SegmentationModel.from_pretrained(args.segmentation, hf_token)
97 | args.embedding = m.EmbeddingModel.from_pretrained(args.embedding, hf_token)
98 |
99 | # Resolve pipeline
100 | pipeline_class = utils.get_pipeline_class(args.pipeline)
101 | config = pipeline_class.get_config_class()(**vars(args))
102 | pipeline = pipeline_class(config)
103 |
104 | # Create websocket audio source
105 | audio_source = src.WebSocketAudioSource(config.sample_rate, args.host, args.port)
106 |
107 | # Run online inference
108 | inference = StreamingInference(
109 | pipeline,
110 | audio_source,
111 | batch_size=1,
112 | do_profile=False,
113 | do_plot=False,
114 | show_progress=True,
115 | )
116 |
117 | # Write to disk if required
118 | if args.output is not None:
119 | inference.attach_observers(
120 | RTTMWriter(audio_source.uri, args.output / f"{audio_source.uri}.rttm")
121 | )
122 |
123 | # Send back responses as RTTM text lines
124 | inference.attach_hooks(lambda ann_wav: audio_source.send(ann_wav[0].to_rttm()))
125 |
126 | # Run server and pipeline
127 | inference()
128 |
129 |
130 | if __name__ == "__main__":
131 | run()
132 |
--------------------------------------------------------------------------------
/src/diart/console/stream.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 |
4 | import torch
5 |
6 | from diart import argdoc
7 | from diart import models as m
8 | from diart import sources as src
9 | from diart import utils
10 | from diart.inference import StreamingInference
11 | from diart.sinks import RTTMWriter
12 |
13 |
14 | def run():
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument(
17 | "source",
18 | type=str,
19 | help="Path to an audio file | 'microphone' | 'microphone:'",
20 | )
21 | parser.add_argument(
22 | "--pipeline",
23 | default="SpeakerDiarization",
24 | type=str,
25 | help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'",
26 | )
27 | parser.add_argument(
28 | "--segmentation",
29 | default="pyannote/segmentation",
30 | type=str,
31 | help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation",
32 | )
33 | parser.add_argument(
34 | "--embedding",
35 | default="pyannote/embedding",
36 | type=str,
37 | help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding",
38 | )
39 | parser.add_argument(
40 | "--duration",
41 | type=float,
42 | default=5,
43 | help=f"{argdoc.DURATION}. Defaults to training segmentation duration",
44 | )
45 | parser.add_argument(
46 | "--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5"
47 | )
48 | parser.add_argument(
49 | "--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5"
50 | )
51 | parser.add_argument(
52 | "--tau-active", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5"
53 | )
54 | parser.add_argument(
55 | "--rho-update", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3"
56 | )
57 | parser.add_argument(
58 | "--delta-new", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1"
59 | )
60 | parser.add_argument(
61 | "--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3"
62 | )
63 | parser.add_argument(
64 | "--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10"
65 | )
66 | parser.add_argument(
67 | "--max-speakers",
68 | default=20,
69 | type=int,
70 | help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20",
71 | )
72 | parser.add_argument(
73 | "--no-plot",
74 | dest="no_plot",
75 | action="store_true",
76 | help="Skip plotting for faster inference",
77 | )
78 | parser.add_argument(
79 | "--cpu",
80 | dest="cpu",
81 | action="store_true",
82 | help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise",
83 | )
84 | parser.add_argument(
85 | "--output",
86 | type=str,
87 | help=f"{argdoc.OUTPUT}. Defaults to home directory if SOURCE == 'microphone' or parent directory if SOURCE is a file",
88 | )
89 | parser.add_argument(
90 | "--hf-token",
91 | default="true",
92 | type=str,
93 | help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)",
94 | )
95 | parser.add_argument(
96 | "--normalize-embedding-weights",
97 | action="store_true",
98 | help=f"{argdoc.NORMALIZE_EMBEDDING_WEIGHTS}. Defaults to False",
99 | )
100 | args = parser.parse_args()
101 |
102 | # Resolve device
103 | args.device = torch.device("cpu") if args.cpu else None
104 |
105 | # Resolve models
106 | hf_token = utils.parse_hf_token_arg(args.hf_token)
107 | args.segmentation = m.SegmentationModel.from_pretrained(args.segmentation, hf_token)
108 | args.embedding = m.EmbeddingModel.from_pretrained(args.embedding, hf_token)
109 |
110 | # Resolve pipeline
111 | pipeline_class = utils.get_pipeline_class(args.pipeline)
112 | config = pipeline_class.get_config_class()(**vars(args))
113 | pipeline = pipeline_class(config)
114 |
115 | # Manage audio source
116 | source_components = args.source.split(":")
117 | if source_components[0] != "microphone":
118 | args.source = Path(args.source).expanduser()
119 | args.output = args.source.parent if args.output is None else Path(args.output)
120 | padding = config.get_file_padding(args.source)
121 | audio_source = src.FileAudioSource(
122 | args.source, config.sample_rate, padding, config.step
123 | )
124 | pipeline.set_timestamp_shift(-padding[0])
125 | else:
126 | args.output = (
127 | Path("~/").expanduser() if args.output is None else Path(args.output)
128 | )
129 | device = int(source_components[1]) if len(source_components) > 1 else None
130 | audio_source = src.MicrophoneAudioSource(config.step, device)
131 |
132 | # Run online inference
133 | inference = StreamingInference(
134 | pipeline,
135 | audio_source,
136 | batch_size=1,
137 | do_profile=True,
138 | do_plot=not args.no_plot,
139 | show_progress=True,
140 | )
141 | inference.attach_observers(
142 | RTTMWriter(audio_source.uri, args.output / f"{audio_source.uri}.rttm")
143 | )
144 | try:
145 | inference()
146 | except KeyboardInterrupt:
147 | pass
148 |
149 |
150 | if __name__ == "__main__":
151 | run()
152 |
--------------------------------------------------------------------------------
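
The microphone branch above corresponds roughly to the following Python, assuming the top-level SpeakerDiarization export and that the pipeline exposes its configuration as pipeline.config; the output path is a placeholder (a sketch, not a file in this repository):

    # Sketch only: live diarization from the default microphone, mirroring
    # the 'microphone' branch of stream.py above.
    from diart import SpeakerDiarization
    from diart.inference import StreamingInference
    from diart.sinks import RTTMWriter
    from diart.sources import MicrophoneAudioSource

    pipeline = SpeakerDiarization()
    mic = MicrophoneAudioSource(pipeline.config.step, None)  # None = default input device
    inference = StreamingInference(pipeline, mic, batch_size=1, do_plot=True)
    inference.attach_observers(RTTMWriter(mic.uri, "output.rttm"))
    try:
        inference()
    except KeyboardInterrupt:
        pass
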
/src/diart/console/tune.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 |
4 | import optuna
5 | import torch
6 | from optuna.samplers import TPESampler
7 |
8 | from diart import argdoc
9 | from diart import models as m
10 | from diart import utils
11 | from diart.blocks.base import HyperParameter
12 | from diart.optim import Optimizer
13 |
14 |
15 | def run():
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument(
18 | "root",
19 | type=str,
20 | help="Directory with audio files CONVERSATION.(wav|flac|m4a|...)",
21 | )
22 | parser.add_argument(
23 | "--reference",
24 | required=True,
25 | type=str,
26 | help="Directory with RTTM files CONVERSATION.rttm. Names must match audio files",
27 | )
28 | parser.add_argument(
29 | "--pipeline",
30 | default="SpeakerDiarization",
31 | type=str,
32 | help="Class of the pipeline to optimize. Defaults to 'SpeakerDiarization'",
33 | )
34 | parser.add_argument(
35 | "--segmentation",
36 | default="pyannote/segmentation",
37 | type=str,
38 | help=f"{argdoc.SEGMENTATION}. Defaults to pyannote/segmentation",
39 | )
40 | parser.add_argument(
41 | "--embedding",
42 | default="pyannote/embedding",
43 | type=str,
44 | help=f"{argdoc.EMBEDDING}. Defaults to pyannote/embedding",
45 | )
46 | parser.add_argument(
47 | "--duration",
48 | type=float,
49 | default=5,
50 | help=f"{argdoc.DURATION}. Defaults to training segmentation duration",
51 | )
52 | parser.add_argument(
53 | "--step", default=0.5, type=float, help=f"{argdoc.STEP}. Defaults to 0.5"
54 | )
55 | parser.add_argument(
56 | "--latency", default=0.5, type=float, help=f"{argdoc.LATENCY}. Defaults to 0.5"
57 | )
58 | parser.add_argument(
59 | "--tau-active", default=0.5, type=float, help=f"{argdoc.TAU}. Defaults to 0.5"
60 | )
61 | parser.add_argument(
62 | "--rho-update", default=0.3, type=float, help=f"{argdoc.RHO}. Defaults to 0.3"
63 | )
64 | parser.add_argument(
65 | "--delta-new", default=1, type=float, help=f"{argdoc.DELTA}. Defaults to 1"
66 | )
67 | parser.add_argument(
68 | "--gamma", default=3, type=float, help=f"{argdoc.GAMMA}. Defaults to 3"
69 | )
70 | parser.add_argument(
71 | "--beta", default=10, type=float, help=f"{argdoc.BETA}. Defaults to 10"
72 | )
73 | parser.add_argument(
74 | "--max-speakers",
75 | default=20,
76 | type=int,
77 | help=f"{argdoc.MAX_SPEAKERS}. Defaults to 20",
78 | )
79 | parser.add_argument(
80 | "--batch-size",
81 | default=32,
82 | type=int,
83 | help=f"{argdoc.BATCH_SIZE}. Defaults to 32",
84 | )
85 | parser.add_argument(
86 | "--cpu",
87 | dest="cpu",
88 | action="store_true",
89 | help=f"{argdoc.CPU}. Defaults to GPU if available, CPU otherwise",
90 | )
91 | parser.add_argument(
92 | "--hparams",
93 | nargs="+",
94 | default=("tau_active", "rho_update", "delta_new"),
95 | help="Hyper-parameters to optimize. Must match names in `PipelineConfig`. Defaults to tau_active, rho_update and delta_new",
96 | )
97 | parser.add_argument(
98 | "--num-iter", default=100, type=int, help="Number of optimization trials"
99 | )
100 | parser.add_argument(
101 | "--storage",
102 | type=str,
103 | help="Optuna storage string. If provided, continue a previous study instead of creating one. The database name must match the study name",
104 | )
105 | parser.add_argument("--output", type=str, help="Working directory")
106 | parser.add_argument(
107 | "--hf-token",
108 | default="true",
109 | type=str,
110 | help=f"{argdoc.HF_TOKEN}. Defaults to 'true' (required by pyannote)",
111 | )
112 | parser.add_argument(
113 | "--normalize-embedding-weights",
114 | action="store_true",
115 | help=f"{argdoc.NORMALIZE_EMBEDDING_WEIGHTS}. Defaults to False",
116 | )
117 | args = parser.parse_args()
118 |
119 | # Resolve device
120 | args.device = torch.device("cpu") if args.cpu else None
121 |
122 | # Resolve models
123 | hf_token = utils.parse_hf_token_arg(args.hf_token)
124 | args.segmentation = m.SegmentationModel.from_pretrained(args.segmentation, hf_token)
125 | args.embedding = m.EmbeddingModel.from_pretrained(args.embedding, hf_token)
126 |
127 | # Retrieve pipeline class
128 | pipeline_class = utils.get_pipeline_class(args.pipeline)
129 |
130 | # Create the base configuration for each trial
131 | base_config = pipeline_class.get_config_class()(**vars(args))
132 |
133 | # Create hyper-parameters to optimize
134 | possible_hparams = pipeline_class.hyper_parameters()
135 | hparams = [HyperParameter.from_name(name) for name in args.hparams]
136 | hparams = [hp for hp in hparams if hp in possible_hparams]
137 | if not hparams:
138 | print(
139 | f"No hyper-parameters to optimize. "
140 | f"Make sure to select one of: {', '.join([hp.name for hp in possible_hparams])}"
141 | )
142 | exit(1)
143 |
144 | # Use a custom storage if given
145 | if args.output is not None:
146 | msg = "Both `output` and `storage` were set, but only one was expected"
147 | assert args.storage is None, msg
148 | args.output = Path(args.output).expanduser()
149 | args.output.mkdir(parents=True, exist_ok=True)
150 | study_or_path = args.output
151 | elif args.storage is not None:
152 | db_name = Path(args.storage).stem
153 | study_or_path = optuna.load_study(db_name, args.storage, TPESampler())
154 | else:
155 | msg = "Please provide either `output` or `storage`"
156 | raise ValueError(msg)
157 |
158 | # Run optimization
159 | Optimizer(
160 | pipeline_class=pipeline_class,
161 | speech_path=args.root,
162 | reference_path=args.reference,
163 | study_or_path=study_or_path,
164 | batch_size=args.batch_size,
165 | hparams=hparams,
166 | base_config=base_config,
167 | )(num_iter=args.num_iter, show_progress=True)
168 |
169 |
170 | if __name__ == "__main__":
171 | run()
172 |
--------------------------------------------------------------------------------
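
When continuing a previous run, the --storage branch above loads the study through Optuna before handing it to the Optimizer. A rough Python equivalent with placeholder paths and study name, assuming the top-level SpeakerDiarization export (a sketch, not a file in this repository):

    # Sketch only: resume a previous tuning run from its Optuna storage.
    import optuna
    from optuna.samplers import TPESampler

    from diart import SpeakerDiarization
    from diart.optim import Optimizer

    storage = "sqlite:///my_study.db"  # the database name must match the study name
    study = optuna.load_study(study_name="my_study", storage=storage, sampler=TPESampler())
    Optimizer(
        pipeline_class=SpeakerDiarization,
        speech_path="/data/audio",    # directory with audio files
        reference_path="/data/rttm",  # directory with matching RTTM references
        study_or_path=study,
    )(num_iter=20, show_progress=True)
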
/src/diart/features.py:
--------------------------------------------------------------------------------
1 | from typing import Union, Optional
2 | from abc import ABC, abstractmethod
3 |
4 | import numpy as np
5 | import torch
6 | from pyannote.core import SlidingWindow, SlidingWindowFeature
7 |
8 | TemporalFeatures = Union[SlidingWindowFeature, np.ndarray, torch.Tensor]
9 |
10 |
11 | class TemporalFeatureFormatterState(ABC):
12 | """
13 |     Represents the format recorded by a temporal feature formatter.
14 |     Its job is to transform temporal features into tensors and
15 |     to restore the original format on other features.
16 | """
17 |
18 | @abstractmethod
19 | def to_tensor(self, features: TemporalFeatures) -> torch.Tensor:
20 | pass
21 |
22 | @abstractmethod
23 | def to_internal_type(self, features: torch.Tensor) -> TemporalFeatures:
24 | """
25 |         Cast `features` to the internal type and remove the batch dimension if required.
26 |
27 | Parameters
28 | ----------
29 | features: torch.Tensor, shape (batch, frames, dim)
30 | Batched temporal features.
31 | Returns
32 | -------
33 | new_features: SlidingWindowFeature or numpy.ndarray or torch.Tensor, shape (batch, frames, dim)
34 | """
35 | pass
36 |
37 |
38 | class SlidingWindowFeatureFormatterState(TemporalFeatureFormatterState):
39 | def __init__(self, duration: float):
40 | self.duration = duration
41 | self._cur_start_time = 0
42 |
43 | def to_tensor(self, features: SlidingWindowFeature) -> torch.Tensor:
44 | msg = "Features sliding window duration and step must be equal"
45 | assert features.sliding_window.duration == features.sliding_window.step, msg
46 | self._cur_start_time = features.sliding_window.start
47 | return torch.from_numpy(features.data)
48 |
49 | def to_internal_type(self, features: torch.Tensor) -> TemporalFeatures:
50 | batch_size, num_frames, _ = features.shape
51 | assert batch_size == 1, "Batched SlidingWindowFeature objects are not supported"
52 | # Calculate resolution
53 | resolution = self.duration / num_frames
54 | # Temporal shift to keep track of current start time
55 | resolution = SlidingWindow(
56 | start=self._cur_start_time, duration=resolution, step=resolution
57 | )
58 | return SlidingWindowFeature(features.squeeze(dim=0).cpu().numpy(), resolution)
59 |
60 |
61 | class NumpyArrayFormatterState(TemporalFeatureFormatterState):
62 | def to_tensor(self, features: np.ndarray) -> torch.Tensor:
63 | return torch.from_numpy(features)
64 |
65 | def to_internal_type(self, features: torch.Tensor) -> TemporalFeatures:
66 | return features.cpu().numpy()
67 |
68 |
69 | class PytorchTensorFormatterState(TemporalFeatureFormatterState):
70 | def to_tensor(self, features: torch.Tensor) -> torch.Tensor:
71 | return features
72 |
73 | def to_internal_type(self, features: torch.Tensor) -> TemporalFeatures:
74 | return features
75 |
76 |
77 | class TemporalFeatureFormatter:
78 | """
79 | Manages the typing and format of temporal features.
80 |     When casting temporal features to torch.Tensor, it remembers their
81 |     type and format so it can later restore them on other temporal features.
82 | """
83 |
84 | def __init__(self):
85 | self.state: Optional[TemporalFeatureFormatterState] = None
86 |
87 | def set_state(self, features: TemporalFeatures):
88 | if isinstance(features, SlidingWindowFeature):
89 | msg = "Features sliding window duration and step must be equal"
90 | assert features.sliding_window.duration == features.sliding_window.step, msg
91 | self.state = SlidingWindowFeatureFormatterState(
92 | features.data.shape[0] * features.sliding_window.duration,
93 | )
94 | elif isinstance(features, np.ndarray):
95 | self.state = NumpyArrayFormatterState()
96 | elif isinstance(features, torch.Tensor):
97 | self.state = PytorchTensorFormatterState()
98 | else:
99 | msg = "Unknown format. Provide one of SlidingWindowFeature, numpy.ndarray, torch.Tensor"
100 | raise ValueError(msg)
101 |
102 | def cast(self, features: TemporalFeatures) -> torch.Tensor:
103 | """
104 | Transform features into a `torch.Tensor` and add batch dimension if missing.
105 |
106 | Parameters
107 | ----------
108 | features: SlidingWindowFeature or numpy.ndarray or torch.Tensor
109 | Shape (frames, dim) or (batch, frames, dim)
110 |
111 | Returns
112 | -------
113 | features: torch.Tensor, shape (batch, frames, dim)
114 | """
115 |         # Record the incoming format so it can be restored later
116 | self.set_state(features)
117 | # Convert features to tensor
118 | data = self.state.to_tensor(features)
119 | # Make sure there's a batch dimension
120 | msg = "Temporal features must be 2D or 3D"
121 | assert data.ndim in (2, 3), msg
122 | if data.ndim == 2:
123 | data = data.unsqueeze(0)
124 | return data.float()
125 |
126 | def restore_type(self, features: torch.Tensor) -> TemporalFeatures:
127 | """
128 | Cast `features` to the internal type and remove batch dimension if required.
129 |
130 | Parameters
131 | ----------
132 | features: torch.Tensor, shape (batch, frames, dim)
133 | Batched temporal features.
134 | Returns
135 | -------
136 | new_features: SlidingWindowFeature or numpy.ndarray or torch.Tensor, shape (batch, frames, dim)
137 | """
138 | return self.state.to_internal_type(features)
139 |
--------------------------------------------------------------------------------
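
A small sketch of the round trip that TemporalFeatureFormatter provides; the array shape is arbitrary (illustrative only, not a file in this repository):

    # Sketch only: cast a 2D numpy array to a batched tensor and restore its type.
    import numpy as np

    from diart.features import TemporalFeatureFormatter

    formatter = TemporalFeatureFormatter()
    scores = np.random.rand(293, 4).astype(np.float32)  # (frames, speakers)
    batched = formatter.cast(scores)            # torch.Tensor, shape (1, 293, 4)
    restored = formatter.restore_type(batched)  # numpy.ndarray again (batch dim kept)
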
/src/diart/functional.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import torch
4 |
5 |
6 | def overlapped_speech_penalty(
7 | segmentation: torch.Tensor, gamma: float = 3, beta: float = 10
8 | ):
9 | # segmentation has shape (batch, frames, speakers)
10 | probs = torch.softmax(beta * segmentation, dim=-1)
11 | weights = torch.pow(segmentation, gamma) * torch.pow(probs, gamma)
12 | weights[weights < 1e-8] = 1e-8
13 | return weights
14 |
15 |
16 | def normalize_embeddings(
17 | embeddings: torch.Tensor, norm: float | torch.Tensor = 1
18 | ) -> torch.Tensor:
19 | # embeddings has shape (batch, speakers, feat) or (speakers, feat)
20 | if embeddings.ndim == 2:
21 | embeddings = embeddings.unsqueeze(0)
22 | if isinstance(norm, torch.Tensor):
23 | batch_size1, num_speakers1, _ = norm.shape
24 | batch_size2, num_speakers2, _ = embeddings.shape
25 | assert batch_size1 == batch_size2 and num_speakers1 == num_speakers2
26 | emb_norm = torch.norm(embeddings, p=2, dim=-1, keepdim=True)
27 | return norm * embeddings / emb_norm
28 |
--------------------------------------------------------------------------------
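
A quick sketch exercising the two helpers above on dummy tensors; all shapes, including the embedding dimension, are placeholders (illustrative only, not a file in this repository):

    # Sketch only: overlap-aware weights and re-normalized embeddings on random data.
    import torch

    from diart.functional import normalize_embeddings, overlapped_speech_penalty

    segmentation = torch.rand(2, 293, 4)               # (batch, frames, speakers)
    weights = overlapped_speech_penalty(segmentation)  # same shape; smaller where speakers overlap
    embeddings = torch.randn(2, 4, 512)                # (batch, speakers, feat)
    unit_embeddings = normalize_embeddings(embeddings)  # each embedding rescaled to L2 norm 1
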
/src/diart/mapping.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Callable, Iterable, List, Optional, Text, Tuple, Union, Dict
4 | from abc import ABC, abstractmethod
5 |
6 | import numpy as np
7 | from pyannote.core.utils.distance import cdist
8 | from scipy.optimize import linear_sum_assignment as lsap
9 |
10 |
11 | class MappingMatrixObjective(ABC):
12 | def invalid_tensor(self, shape: Union[Tuple, int]) -> np.ndarray:
13 | return np.ones(shape) * self.invalid_value
14 |
15 | def optimal_assignments(self, matrix: np.ndarray) -> List[int]:
16 | return list(lsap(matrix, self.maximize)[1])
17 |
18 | def mapped_indices(self, matrix: np.ndarray, axis: int) -> List[int]:
19 | # Entries full of invalid_value are not mapped
20 | best_values = self.best_value_fn(matrix, axis=axis)
21 | return list(np.where(best_values != self.invalid_value)[0])
22 |
23 | def hard_speaker_map(
24 | self, num_src: int, num_tgt: int, assignments: Iterable[Tuple[int, int]]
25 | ) -> SpeakerMap:
26 | """Create a hard map object where the highest cost is put
27 | everywhere except on hard assignments from ``assignments``.
28 |
29 | Parameters
30 | ----------
31 | num_src: int
32 | Number of source speakers
33 | num_tgt: int
34 | Number of target speakers
35 | assignments: Iterable[Tuple[int, int]]
36 |             An iterable of tuples whose first element is the source speaker index
37 |             and whose second element is the target speaker index
38 |
39 | Returns
40 | -------
41 | SpeakerMap
42 | """
43 | mapping_matrix = self.invalid_tensor(shape=(num_src, num_tgt))
44 | for src, tgt in assignments:
45 | mapping_matrix[src, tgt] = self.best_possible_value
46 | return SpeakerMap(mapping_matrix, self)
47 |
48 | @property
49 | def invalid_value(self) -> float:
50 | # linear_sum_assignment cannot deal with np.inf,
51 | # which would be ideal. Using a big number instead.
52 | return -1e10 if self.maximize else 1e10
53 |
54 | @property
55 | @abstractmethod
56 | def maximize(self) -> bool:
57 | pass
58 |
59 | @property
60 | @abstractmethod
61 | def best_possible_value(self) -> float:
62 | pass
63 |
64 | @property
65 | @abstractmethod
66 | def best_value_fn(self) -> Callable:
67 | pass
68 |
69 |
70 | class MinimizationObjective(MappingMatrixObjective):
71 | @property
72 | def maximize(self) -> bool:
73 | return False
74 |
75 | @property
76 | def best_possible_value(self) -> float:
77 | return 0
78 |
79 | @property
80 | def best_value_fn(self) -> Callable:
81 | return np.min
82 |
83 |
84 | class MaximizationObjective(MappingMatrixObjective):
85 | def __init__(self, max_value: float = 1):
86 | self.max_value = max_value
87 |
88 | @property
89 | def maximize(self) -> bool:
90 | return True
91 |
92 | @property
93 | def best_possible_value(self) -> float:
94 | return self.max_value
95 |
96 | @property
97 | def best_value_fn(self) -> Callable:
98 | return np.max
99 |
100 |
101 | class SpeakerMapBuilder:
102 | @staticmethod
103 | def hard_map(
104 | shape: Tuple[int, int], assignments: Iterable[Tuple[int, int]], maximize: bool
105 | ) -> SpeakerMap:
106 | """Create a ``SpeakerMap`` object based on the given assignments. This is a "hard" map, meaning that the
107 | highest cost is put everywhere except on hard assignments from ``assignments``.
108 |
109 | Parameters
110 | ----------
111 |         shape: Tuple[int, int]
112 | Shape of the mapping matrix
113 | assignments: Iterable[Tuple[int, int]]
114 |             An iterable of tuples whose first element is the source speaker index
115 |             and whose second element is the target speaker index
116 |         maximize: bool
117 |             Whether to use scores where higher is better (True) or lower is better (False)
118 |
119 | Returns
120 | -------
121 | SpeakerMap
122 | """
123 | num_src, num_tgt = shape
124 | objective = MaximizationObjective if maximize else MinimizationObjective
125 | return objective().hard_speaker_map(num_src, num_tgt, assignments)
126 |
127 | @staticmethod
128 | def correlation(scores1: np.ndarray, scores2: np.ndarray) -> SpeakerMap:
129 | score_matrix_per_frame = (
130 | np.stack( # (local_speakers, num_frames, global_speakers)
131 | [
132 | scores1[:, speaker : speaker + 1] * scores2
133 | for speaker in range(scores1.shape[1])
134 | ],
135 | axis=0,
136 | )
137 | )
138 | # Calculate total speech "activations" per local speaker
139 | local_speech_scores = np.sum(scores1, axis=0).reshape(-1, 1)
140 | # Calculate speaker mapping matrix
141 | # Cost matrix is the correlation divided by sum of local activations
142 | score_matrix = np.sum(score_matrix_per_frame, axis=1) / local_speech_scores
143 | # We want to maximize the correlation to calculate optimal speaker alignments
144 | return SpeakerMap(score_matrix, MaximizationObjective(max_value=1))
145 |
146 | @staticmethod
147 | def mse(scores1: np.ndarray, scores2: np.ndarray) -> SpeakerMap:
148 | cost_matrix = np.stack( # (local_speakers, local_speakers)
149 | [
150 | np.square(scores1[:, speaker : speaker + 1] - scores2).mean(axis=0)
151 | for speaker in range(scores1.shape[1])
152 | ],
153 | axis=0,
154 | )
155 | # We want to minimize the MSE to calculate optimal speaker alignments
156 | return SpeakerMap(cost_matrix, MinimizationObjective())
157 |
158 | @staticmethod
159 | def mae(scores1: np.ndarray, scores2: np.ndarray) -> SpeakerMap:
160 | cost_matrix = np.stack( # (local_speakers, local_speakers)
161 | [
162 | np.abs(scores1[:, speaker : speaker + 1] - scores2).mean(axis=0)
163 | for speaker in range(scores1.shape[1])
164 | ],
165 | axis=0,
166 | )
167 |         # We want to minimize the MAE to calculate optimal speaker alignments
168 | return SpeakerMap(cost_matrix, MinimizationObjective())
169 |
170 | @staticmethod
171 | def dist(
172 | embeddings1: np.ndarray, embeddings2: np.ndarray, metric: Text = "cosine"
173 | ) -> SpeakerMap:
174 | # We want to minimize the distance to calculate optimal speaker alignments
175 | dist_matrix = cdist(embeddings1, embeddings2, metric=metric)
176 | return SpeakerMap(dist_matrix, MinimizationObjective())
177 |
178 |
179 | class SpeakerMap:
180 | def __init__(self, mapping_matrix: np.ndarray, objective: MappingMatrixObjective):
181 | self.mapping_matrix = mapping_matrix
182 | self.objective = objective
183 | self.num_source_speakers = self.mapping_matrix.shape[0]
184 | self.num_target_speakers = self.mapping_matrix.shape[1]
185 | self.mapped_source_speakers = self.objective.mapped_indices(
186 | self.mapping_matrix, axis=1
187 | )
188 | self.mapped_target_speakers = self.objective.mapped_indices(
189 | self.mapping_matrix, axis=0
190 | )
191 | self._opt_assignments: Optional[List[int]] = None
192 |
193 | @property
194 | def _raw_optimal_assignments(self) -> List[int]:
195 | if self._opt_assignments is None:
196 | self._opt_assignments = self.objective.optimal_assignments(
197 | self.mapping_matrix
198 | )
199 | return self._opt_assignments
200 |
201 | @property
202 | def shape(self) -> Tuple[int, int]:
203 | return self.mapping_matrix.shape
204 |
205 | def __len__(self):
206 | return len(self.mapped_source_speakers)
207 |
208 | def __add__(self, other: SpeakerMap) -> SpeakerMap:
209 | return self.union(other)
210 |
211 | def _strict_check_valid(self, src: int, tgt: int) -> bool:
212 | return self.mapping_matrix[src, tgt] != self.objective.invalid_value
213 |
214 | def _loose_check_valid(self, src: int, tgt: int) -> bool:
215 | return self.is_source_speaker_mapped(src)
216 |
217 | def valid_assignments(
218 | self,
219 | strict: bool = False,
220 | as_array: bool = False,
221 | ) -> Union[Tuple[List[int], List[int]], Tuple[np.ndarray, np.ndarray]]:
222 | source, target = [], []
223 | val_type = "strict" if strict else "loose"
224 | is_valid = getattr(self, f"_{val_type}_check_valid")
225 | for src, tgt in enumerate(self._raw_optimal_assignments):
226 | if is_valid(src, tgt):
227 | source.append(src)
228 | target.append(tgt)
229 | if as_array:
230 | source, target = np.array(source), np.array(target)
231 | return source, target
232 |
233 | def to_dict(self, strict: bool = False) -> Dict[int, int]:
234 | return {src: tgt for src, tgt in zip(*self.valid_assignments(strict))}
235 |
236 | def to_inverse_dict(self, strict: bool = False) -> Dict[int, int]:
237 | return {tgt: src for src, tgt in zip(*self.valid_assignments(strict))}
238 |
239 | def is_source_speaker_mapped(self, source_speaker: int) -> bool:
240 | return source_speaker in self.mapped_source_speakers
241 |
242 | def is_target_speaker_mapped(self, target_speaker: int) -> bool:
243 | return target_speaker in self.mapped_target_speakers
244 |
245 | def set_source_speaker(self, src_speaker, tgt_speaker: int):
246 | # if not force:
247 | # assert not self.is_source_speaker_mapped(src_speaker)
248 | # assert not self.is_target_speaker_mapped(tgt_speaker)
249 | new_cost_matrix = np.copy(self.mapping_matrix)
250 | new_cost_matrix[src_speaker, tgt_speaker] = self.objective.best_possible_value
251 | return SpeakerMap(new_cost_matrix, self.objective)
252 |
253 | def unmap_source_speaker(self, src_speaker: int):
254 | new_cost_matrix = np.copy(self.mapping_matrix)
255 | new_cost_matrix[src_speaker] = self.objective.invalid_tensor(
256 | shape=self.num_target_speakers
257 | )
258 | return SpeakerMap(new_cost_matrix, self.objective)
259 |
260 | def unmap_threshold(self, threshold: float) -> SpeakerMap:
261 | def is_invalid(val):
262 | if self.objective.maximize:
263 | return val <= threshold
264 | else:
265 | return val >= threshold
266 |
267 | return self.unmap_speakers(
268 | [
269 | src
270 | for src, tgt in zip(*self.valid_assignments())
271 | if is_invalid(self.mapping_matrix[src, tgt])
272 | ]
273 | )
274 |
275 | def unmap_speakers(
276 | self,
277 | source_speakers: Optional[Union[List[int], np.ndarray]] = None,
278 | target_speakers: Optional[Union[List[int], np.ndarray]] = None,
279 | ) -> SpeakerMap:
280 | # Set invalid_value to disabled speakers.
281 | # If they happen to be the best mapping for a local speaker,
282 | # it means that the mapping of the local speaker should be ignored.
283 | source_speakers = [] if source_speakers is None else source_speakers
284 | target_speakers = [] if target_speakers is None else target_speakers
285 | new_cost_matrix = np.copy(self.mapping_matrix)
286 | for speaker1 in source_speakers:
287 | new_cost_matrix[speaker1] = self.objective.invalid_tensor(
288 | shape=self.num_target_speakers
289 | )
290 | for speaker2 in target_speakers:
291 | new_cost_matrix[:, speaker2] = self.objective.invalid_tensor(
292 | shape=self.num_source_speakers
293 | )
294 | return SpeakerMap(new_cost_matrix, self.objective)
295 |
296 | def compose(self, other: SpeakerMap) -> SpeakerMap:
297 | """Let's say that `self` is a mapping of `source_speakers` to `intermediate_speakers`
298 | and `other` is a mapping from `intermediate_speakers` to `target_speakers`.
299 |
300 | Compose `self` with `other` to obtain a new mapping from `source_speakers` to `target_speakers`.
301 | """
302 | new_cost_matrix = other.objective.invalid_tensor(
303 | shape=(self.num_source_speakers, other.num_target_speakers)
304 | )
305 | for src_speaker, intermediate_speaker in zip(*self.valid_assignments()):
306 | target_speaker = other.mapping_matrix[intermediate_speaker]
307 | new_cost_matrix[src_speaker] = target_speaker
308 | return SpeakerMap(new_cost_matrix, other.objective)
309 |
310 | def union(self, other: SpeakerMap):
311 | """`self` and `other` are two maps with the same dimensions.
312 | Return a new hard speaker map containing assignments in both maps.
313 |
314 | An assignment from `other` is ignored if it is in conflict with
315 | a source or target speaker from `self`.
316 |
317 | WARNING: The resulting map doesn't preserve soft assignments
318 | because `self` and `other` might have different objectives.
319 |
320 | :param other: SpeakerMap
321 | Another speaker map
322 | """
323 | assert self.shape == other.shape
324 | best_val = self.objective.best_possible_value
325 | new_cost_matrix = self.objective.invalid_tensor(self.shape)
326 | self_src, self_tgt = self.valid_assignments()
327 | other_src, other_tgt = other.valid_assignments()
328 | for src in range(self.num_source_speakers):
329 | if src in self_src:
330 | # `self` is preserved by default
331 | tgt = self_tgt[self_src.index(src)]
332 | new_cost_matrix[src, tgt] = best_val
333 | elif src in other_src:
334 | # In order to add an assignment from `other`,
335 | # the target speaker cannot be in conflict with `self`
336 | tgt = other_tgt[other_src.index(src)]
337 | if not self.is_target_speaker_mapped(tgt):
338 | new_cost_matrix[src, tgt] = best_val
339 | return SpeakerMap(new_cost_matrix, self.objective)
340 |
341 | def apply(self, source_scores: np.ndarray) -> np.ndarray:
342 | """Apply this mapping to a score matrix of source speakers
343 | to obtain the same scores aligned to target speakers.
344 |
345 | Parameters
346 | ----------
347 | source_scores : SlidingWindowFeature, (num_frames, num_source_speakers)
348 | Source speaker scores per frame.
349 |
350 | Returns
351 | -------
352 |         projected_scores : numpy.ndarray, (num_frames, num_target_speakers)
353 | Score matrix for target speakers.
354 | """
355 | # Map local speaker scores to the most probable global speaker. Unknown scores are set to 0
356 | num_frames = source_scores.data.shape[0]
357 | projected_scores = np.zeros((num_frames, self.num_target_speakers))
358 | for src_speaker, tgt_speaker in zip(*self.valid_assignments()):
359 | projected_scores[:, tgt_speaker] = source_scores[:, src_speaker]
360 | return projected_scores
361 |
--------------------------------------------------------------------------------
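
A short sketch of the typical flow through this module: build a SpeakerMap from two score matrices, read the optimal assignment, and project the local scores onto the target speakers. Shapes and values are arbitrary (illustrative only, not a file in this repository):

    # Sketch only: align 3 local speakers to 5 global speakers by correlation.
    import numpy as np

    from diart.mapping import SpeakerMapBuilder

    local_scores = np.random.rand(293, 3)   # (num_frames, local_speakers)
    global_scores = np.random.rand(293, 5)  # (num_frames, global_speakers)

    speaker_map = SpeakerMapBuilder.correlation(local_scores, global_scores)
    assignment = speaker_map.to_dict()         # e.g. {0: 2, 1: 0, 2: 4}
    aligned = speaker_map.apply(local_scores)  # (num_frames, global_speakers)
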
/src/diart/models.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from abc import ABC
4 | from pathlib import Path
5 | from typing import Optional, Text, Union, Callable, List
6 |
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | from requests import HTTPError
11 |
12 | try:
13 | from pyannote.audio import Model
14 | from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
15 | from pyannote.audio.utils.powerset import Powerset
16 |
17 | IS_PYANNOTE_AVAILABLE = True
18 | except ImportError:
19 | IS_PYANNOTE_AVAILABLE = False
20 |
21 | try:
22 | import onnxruntime as ort
23 |
24 | IS_ONNX_AVAILABLE = True
25 | except ImportError:
26 | IS_ONNX_AVAILABLE = False
27 |
28 |
29 | class PowersetAdapter(nn.Module):
30 | def __init__(self, segmentation_model: nn.Module):
31 | super().__init__()
32 | self.model = segmentation_model
33 | specs = self.model.specifications
34 | max_speakers_per_frame = specs.powerset_max_classes
35 | max_speakers_per_chunk = len(specs.classes)
36 | self.powerset = Powerset(max_speakers_per_chunk, max_speakers_per_frame)
37 |
38 | def forward(self, waveform: torch.Tensor) -> torch.Tensor:
39 | return self.powerset.to_multilabel(self.model(waveform))
40 |
41 |
42 | class PyannoteLoader:
43 | def __init__(self, model_info, hf_token: Union[Text, bool, None] = True):
44 | super().__init__()
45 | self.model_info = model_info
46 | self.hf_token = hf_token
47 |
48 | def __call__(self) -> Callable:
49 | try:
50 | model = Model.from_pretrained(self.model_info, use_auth_token=self.hf_token)
51 | specs = getattr(model, "specifications", None)
52 | if specs is not None and specs.powerset:
53 | model = PowersetAdapter(model)
54 | return model
55 | except HTTPError:
56 | pass
57 | except ModuleNotFoundError:
58 | pass
59 | return PretrainedSpeakerEmbedding(self.model_info, use_auth_token=self.hf_token)
60 |
61 |
62 | class ONNXLoader:
63 | def __init__(self, path: str | Path, input_names: List[str], output_name: str):
64 | super().__init__()
65 | self.path = Path(path)
66 | self.input_names = input_names
67 | self.output_name = output_name
68 |
69 | def __call__(self) -> ONNXModel:
70 | return ONNXModel(self.path, self.input_names, self.output_name)
71 |
72 |
73 | class ONNXModel:
74 | def __init__(self, path: Path, input_names: List[str], output_name: str):
75 | super().__init__()
76 | self.path = path
77 | self.input_names = input_names
78 | self.output_name = output_name
79 | self.device = torch.device("cpu")
80 | self.session = None
81 | self.recreate_session()
82 |
83 | @property
84 | def execution_provider(self) -> str:
85 | device = "CUDA" if self.device.type == "cuda" else "CPU"
86 | return f"{device}ExecutionProvider"
87 |
88 | def recreate_session(self):
89 | options = ort.SessionOptions()
90 | options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
91 | self.session = ort.InferenceSession(
92 | self.path,
93 | sess_options=options,
94 | providers=[self.execution_provider],
95 | )
96 |
97 | def to(self, device: torch.device) -> ONNXModel:
98 | if device.type != self.device.type:
99 | self.device = device
100 | self.recreate_session()
101 | return self
102 |
103 | def __call__(self, *args) -> torch.Tensor:
104 | inputs = {
105 | name: arg.cpu().numpy().astype(np.float32)
106 | for name, arg in zip(self.input_names, args)
107 | }
108 | output = self.session.run([self.output_name], inputs)[0]
109 | return torch.from_numpy(output).float().to(args[0].device)
110 |
111 |
112 | class LazyModel(ABC):
113 | def __init__(self, loader: Callable[[], Callable]):
114 | super().__init__()
115 | self.get_model = loader
116 | self.model: Optional[Callable] = None
117 |
118 | def is_in_memory(self) -> bool:
119 | """Return whether the model has been loaded into memory"""
120 | return self.model is not None
121 |
122 | def load(self):
123 | if not self.is_in_memory():
124 | self.model = self.get_model()
125 |
126 | def to(self, device: torch.device) -> LazyModel:
127 | self.load()
128 | self.model = self.model.to(device)
129 | return self
130 |
131 | def __call__(self, *args, **kwargs):
132 | self.load()
133 | return self.model(*args, **kwargs)
134 |
135 | def eval(self) -> LazyModel:
136 | self.load()
137 | if isinstance(self.model, nn.Module):
138 | self.model.eval()
139 | return self
140 |
141 |
142 | class SegmentationModel(LazyModel):
143 | """
144 | Minimal interface for a segmentation model.
145 | """
146 |
147 | @staticmethod
148 | def from_pyannote(
149 | model, use_hf_token: Union[Text, bool, None] = True
150 | ) -> "SegmentationModel":
151 | """
152 | Returns a `SegmentationModel` wrapping a pyannote model.
153 |
154 | Parameters
155 | ----------
156 | model: pyannote.PipelineModel
157 | The pyannote.audio model to fetch.
158 | use_hf_token: str | bool, optional
159 | The Huggingface access token to use when downloading the model.
160 | If True, use huggingface-cli login token.
161 |             Defaults to True.
162 |
163 | Returns
164 | -------
165 | wrapper: SegmentationModel
166 | """
167 | assert IS_PYANNOTE_AVAILABLE, "No pyannote.audio installation found"
168 | return SegmentationModel(PyannoteLoader(model, use_hf_token))
169 |
170 | @staticmethod
171 | def from_onnx(
172 | model_path: Union[str, Path],
173 | input_name: str = "waveform",
174 | output_name: str = "segmentation",
175 | ) -> "SegmentationModel":
176 | assert IS_ONNX_AVAILABLE, "No ONNX installation found"
177 | return SegmentationModel(ONNXLoader(model_path, [input_name], output_name))
178 |
179 | @staticmethod
180 | def from_pretrained(
181 | model, use_hf_token: Union[Text, bool, None] = True
182 | ) -> "SegmentationModel":
183 | if isinstance(model, str) or isinstance(model, Path):
184 | if Path(model).name.endswith(".onnx"):
185 | return SegmentationModel.from_onnx(model)
186 | return SegmentationModel.from_pyannote(model, use_hf_token)
187 |
188 | def __call__(self, waveform: torch.Tensor) -> torch.Tensor:
189 | """
190 | Call the forward pass of the segmentation model.
191 | Parameters
192 | ----------
193 | waveform: torch.Tensor, shape (batch, channels, samples)
194 | Returns
195 | -------
196 | speaker_segmentation: torch.Tensor, shape (batch, frames, speakers)
197 | """
198 | return super().__call__(waveform)
199 |
200 |
201 | class EmbeddingModel(LazyModel):
202 | """Minimal interface for an embedding model."""
203 |
204 | @staticmethod
205 | def from_pyannote(
206 | model, use_hf_token: Union[Text, bool, None] = True
207 | ) -> "EmbeddingModel":
208 | """
209 | Returns an `EmbeddingModel` wrapping a pyannote model.
210 |
211 | Parameters
212 | ----------
213 | model: pyannote.PipelineModel
214 | The pyannote.audio model to fetch.
215 | use_hf_token: str | bool, optional
216 | The Huggingface access token to use when downloading the model.
217 | If True, use huggingface-cli login token.
218 |             Defaults to True.
219 |
220 | Returns
221 | -------
222 | wrapper: EmbeddingModel
223 | """
224 | assert IS_PYANNOTE_AVAILABLE, "No pyannote.audio installation found"
225 | loader = PyannoteLoader(model, use_hf_token)
226 | return EmbeddingModel(loader)
227 |
228 | @staticmethod
229 | def from_onnx(
230 | model_path: Union[str, Path],
231 | input_names: List[str] | None = None,
232 | output_name: str = "embedding",
233 | ) -> "EmbeddingModel":
234 | assert IS_ONNX_AVAILABLE, "No ONNX installation found"
235 | input_names = input_names or ["waveform", "weights"]
236 | loader = ONNXLoader(model_path, input_names, output_name)
237 | return EmbeddingModel(loader)
238 |
239 | @staticmethod
240 | def from_pretrained(
241 | model, use_hf_token: Union[Text, bool, None] = True
242 | ) -> "EmbeddingModel":
243 | if isinstance(model, str) or isinstance(model, Path):
244 | if Path(model).name.endswith(".onnx"):
245 | return EmbeddingModel.from_onnx(model)
246 | return EmbeddingModel.from_pyannote(model, use_hf_token)
247 |
248 | def __call__(
249 | self, waveform: torch.Tensor, weights: Optional[torch.Tensor] = None
250 | ) -> torch.Tensor:
251 | """
252 | Call the forward pass of an embedding model with optional weights.
253 | Parameters
254 | ----------
255 | waveform: torch.Tensor, shape (batch, channels, samples)
256 | weights: Optional[torch.Tensor], shape (batch, frames)
257 | Temporal weights for each sample in the batch. Defaults to no weights.
258 | Returns
259 | -------
260 | speaker_embeddings: torch.Tensor, shape (batch, embedding_dim)
261 | """
262 | embeddings = super().__call__(waveform, weights)
263 | if isinstance(embeddings, np.ndarray):
264 | embeddings = torch.from_numpy(embeddings)
265 | return embeddings
266 |
--------------------------------------------------------------------------------
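
Both wrappers are lazy: nothing is downloaded or instantiated until the first call, eval() or to(device). A brief sketch of the two loading paths; the model names are the pyannote defaults used by the console scripts and the ONNX path is a placeholder (illustrative only, not a file in this repository):

    # Sketch only: lazy model loading from the Hugging Face hub or a local ONNX file.
    from diart.models import EmbeddingModel, SegmentationModel

    segmentation = SegmentationModel.from_pretrained("pyannote/segmentation")
    embedding = EmbeddingModel.from_pretrained("pyannote/embedding")

    # Any path ending in .onnx is routed to the ONNX loader instead of pyannote
    onnx_segmentation = SegmentationModel.from_pretrained("/path/to/segmentation.onnx")
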
/src/diart/operators.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Callable, Optional, List, Any, Tuple
3 |
4 | import numpy as np
5 | import rx
6 | from pyannote.core import Annotation, SlidingWindow, SlidingWindowFeature, Segment
7 | from rx import operators as ops
8 | from rx.core import Observable
9 |
10 | Operator = Callable[[Observable], Observable]
11 |
12 |
13 | @dataclass
14 | class AudioBufferState:
15 | chunk: Optional[np.ndarray]
16 | buffer: Optional[np.ndarray]
17 | start_time: float
18 | changed: bool
19 |
20 | @staticmethod
21 | def initial():
22 | return AudioBufferState(None, None, 0, False)
23 |
24 | @staticmethod
25 | def has_samples(num_samples: int):
26 | def call_fn(state) -> bool:
27 | return state.chunk is not None and state.chunk.shape[1] == num_samples
28 |
29 | return call_fn
30 |
31 | @staticmethod
32 | def to_sliding_window(sample_rate: int):
33 | def call_fn(state) -> SlidingWindowFeature:
34 | resolution = SlidingWindow(
35 | start=state.start_time,
36 | duration=1.0 / sample_rate,
37 | step=1.0 / sample_rate,
38 | )
39 | return SlidingWindowFeature(state.chunk.T, resolution)
40 |
41 | return call_fn
42 |
43 |
44 | def rearrange_audio_stream(
45 | duration: float = 5, step: float = 0.5, sample_rate: int = 16000
46 | ) -> Operator:
47 | chunk_samples = int(round(sample_rate * duration))
48 | step_samples = int(round(sample_rate * step))
49 |
50 | # FIXME this should flush buffer contents when the audio stops being emitted.
51 |     # Right now this can be avoided by using a block size that divides the step size evenly.
52 |
53 | def accumulate(state: AudioBufferState, value: np.ndarray):
54 | # State contains the last emitted chunk, the current step buffer and the last starting time
55 | if value.ndim != 2 or value.shape[0] != 1:
56 | raise ValueError(
57 | f"Waveform must have shape (1, samples) but {value.shape} was found"
58 | )
59 | start_time = state.start_time
60 |
61 | # Add new samples to the buffer
62 | buffer = (
63 | value
64 | if state.buffer is None
65 | else np.concatenate([state.buffer, value], axis=1)
66 | )
67 |
68 | # Check for buffer overflow
69 | if buffer.shape[1] >= step_samples:
70 | # Pop samples from buffer
71 | if buffer.shape[1] == step_samples:
72 | new_chunk, new_buffer = buffer, None
73 | else:
74 | new_chunk = buffer[:, :step_samples]
75 | new_buffer = buffer[:, step_samples:]
76 |
77 | # Add samples to next chunk
78 | if state.chunk is not None:
79 | new_chunk = np.concatenate([state.chunk, new_chunk], axis=1)
80 |
81 | # Truncate chunk to ensure a fixed duration
82 | if new_chunk.shape[1] > chunk_samples:
83 | new_chunk = new_chunk[:, -chunk_samples:]
84 | start_time += step
85 |
86 | # Chunk has changed because of buffer overflow
87 | return AudioBufferState(new_chunk, new_buffer, start_time, changed=True)
88 |
89 | # Chunk has not changed
90 | return AudioBufferState(state.chunk, buffer, start_time, changed=False)
91 |
92 | return rx.pipe(
93 | # Accumulate last <=duration seconds of waveform as an AudioBufferState
94 | ops.scan(accumulate, AudioBufferState.initial()),
95 | # Take only states that have the desired duration and whose chunk has changed since last time
96 | ops.filter(AudioBufferState.has_samples(chunk_samples)),
97 | ops.filter(lambda state: state.changed),
98 | # Transform state into a SlidingWindowFeature containing the new chunk
99 | ops.map(AudioBufferState.to_sliding_window(sample_rate)),
100 | )
101 |
102 |
103 | def buffer_slide(n: int):
104 | def accumulate(state: List[Any], value: Any) -> List[Any]:
105 | new_state = [*state, value]
106 | if len(new_state) > n:
107 | return new_state[1:]
108 | return new_state
109 |
110 | return rx.pipe(ops.scan(accumulate, []))
111 |
112 |
113 | @dataclass
114 | class PredictionWithAudio:
115 | prediction: Annotation
116 | waveform: Optional[SlidingWindowFeature] = None
117 |
118 | @property
119 | def has_audio(self) -> bool:
120 | return self.waveform is not None
121 |
122 |
123 | @dataclass
124 | class OutputAccumulationState:
125 | annotation: Optional[Annotation]
126 | waveform: Optional[SlidingWindowFeature]
127 | real_time: float
128 | next_sample: Optional[int]
129 |
130 | @staticmethod
131 | def initial() -> "OutputAccumulationState":
132 | return OutputAccumulationState(None, None, 0, 0)
133 |
134 | @property
135 | def cropped_waveform(self) -> SlidingWindowFeature:
136 | return SlidingWindowFeature(
137 | self.waveform[: self.next_sample],
138 | self.waveform.sliding_window,
139 | )
140 |
141 | def to_tuple(
142 | self,
143 | ) -> Tuple[Optional[Annotation], Optional[SlidingWindowFeature], float]:
144 | return self.annotation, self.cropped_waveform, self.real_time
145 |
146 |
147 | def accumulate_output(
148 | duration: float,
149 | step: float,
150 | patch_collar: float = 0.05,
151 | ) -> Operator:
152 | """Accumulate predictions and audio to infinity: O(N) space complexity.
153 | Uses a pre-allocated buffer that doubles its size once full: O(logN) concat operations.
154 |
155 | Parameters
156 | ----------
157 | duration: float
158 | Buffer duration in seconds.
159 | step: float
160 | Duration of the chunks at each event in seconds.
161 | The first chunk may be bigger given the latency.
162 | patch_collar: float, optional
163 | Collar to merge speaker turns of the same speaker, in seconds.
164 | Defaults to 0.05 (i.e. 50ms).
165 | Returns
166 | -------
167 |     A ReactiveX operator implementing this behavior.
168 | """
169 |
170 | def accumulate(
171 | state: OutputAccumulationState,
172 | value: Tuple[Annotation, Optional[SlidingWindowFeature]],
173 | ) -> OutputAccumulationState:
174 | value = PredictionWithAudio(*value)
175 | annotation, waveform = None, None
176 |
177 | # Determine the real time of the stream
178 | real_time = duration if state.annotation is None else state.real_time + step
179 |
180 | # Update total annotation with current predictions
181 | if state.annotation is None:
182 | annotation = value.prediction
183 | else:
184 | annotation = state.annotation.update(value.prediction).support(patch_collar)
185 |
186 | # Update total waveform if there's audio in the input
187 | new_next_sample = 0
188 | if value.has_audio:
189 | num_new_samples = value.waveform.data.shape[0]
190 | new_next_sample = state.next_sample + num_new_samples
191 | sw_holder = state
192 | if state.waveform is None:
193 | # Initialize the audio buffer with 10 times the size of the first chunk
194 | waveform, sw_holder = np.zeros((10 * num_new_samples, 1)), value
195 | elif new_next_sample < state.waveform.data.shape[0]:
196 | # The buffer still has enough space to accommodate the chunk
197 | waveform = state.waveform.data
198 | else:
199 | # The buffer is full, double its size
200 | waveform = np.concatenate(
201 | (state.waveform.data, np.zeros_like(state.waveform.data)), axis=0
202 | )
203 | # Copy chunk into buffer
204 | waveform[state.next_sample : new_next_sample] = value.waveform.data
205 | waveform = SlidingWindowFeature(waveform, sw_holder.waveform.sliding_window)
206 |
207 | return OutputAccumulationState(annotation, waveform, real_time, new_next_sample)
208 |
209 | return rx.pipe(
210 | ops.scan(accumulate, OutputAccumulationState.initial()),
211 | ops.map(OutputAccumulationState.to_tuple),
212 | )
213 |
214 |
215 | def buffer_output(
216 | duration: float,
217 | step: float,
218 | latency: float,
219 | sample_rate: int,
220 | patch_collar: float = 0.05,
221 | ) -> Operator:
222 | """Store last predictions and audio inside a fixed buffer.
223 | Provides the best time/space complexity trade-off if the past data is not needed.
224 |
225 | Parameters
226 | ----------
227 | duration: float
228 | Buffer duration in seconds.
229 | step: float
230 | Duration of the chunks at each event in seconds.
231 | The first chunk may be bigger given the latency.
232 | latency: float
233 | Latency of the system in seconds.
234 | sample_rate: int
235 | Sample rate of the audio source.
236 | patch_collar: float, optional
237 | Collar to merge speaker turns of the same speaker, in seconds.
238 | Defaults to 0.05 (i.e. 50ms).
239 |
240 | Returns
241 | -------
242 |     A ReactiveX operator implementing this behavior.
243 | """
244 | # Define some useful constants
245 | num_samples = int(round(duration * sample_rate))
246 | num_step_samples = int(round(step * sample_rate))
247 | resolution = 1 / sample_rate
248 |
249 | def accumulate(
250 | state: OutputAccumulationState,
251 | value: Tuple[Annotation, Optional[SlidingWindowFeature]],
252 | ) -> OutputAccumulationState:
253 | value = PredictionWithAudio(*value)
254 | annotation, waveform = None, None
255 |
256 | # Determine the real time of the stream and the start time of the buffer
257 | real_time = duration if state.annotation is None else state.real_time + step
258 | start_time = max(0.0, real_time - latency - duration)
259 |
260 | # Update annotation and constrain its bounds to the buffer
261 | if state.annotation is None:
262 | annotation = value.prediction
263 | else:
264 | annotation = state.annotation.update(value.prediction).support(patch_collar)
265 | if start_time > 0:
266 | annotation = annotation.extrude(Segment(0, start_time))
267 |
268 | # Update the audio buffer if there's audio in the input
269 | new_next_sample = state.next_sample + num_step_samples
270 | if value.has_audio:
271 | if state.waveform is None:
272 | # Determine the size of the first chunk
273 | expected_duration = duration + step - latency
274 | expected_samples = int(round(expected_duration * sample_rate))
275 | # Shift indicator to start copying new audio in the buffer
276 | new_next_sample = state.next_sample + expected_samples
277 | # Buffer size is duration + step
278 | waveform = np.zeros((num_samples + num_step_samples, 1))
279 | # Copy first chunk into buffer (slicing because of rounding errors)
280 | waveform[:expected_samples] = value.waveform.data[:expected_samples]
281 | elif state.next_sample <= num_samples:
282 | # The buffer isn't full, copy into next free buffer chunk
283 | waveform = state.waveform.data
284 | waveform[state.next_sample : new_next_sample] = value.waveform.data
285 | else:
286 | # The buffer is full, shift values to the left and copy into last buffer chunk
287 | waveform = np.roll(state.waveform.data, -num_step_samples, axis=0)
288 | # If running on a file, the online prediction may be shorter depending on the latency
289 | # The remaining audio at the end is appended, so value.waveform may be longer than num_step_samples
290 | # In that case, we simply ignore the appended samples.
291 | waveform[-num_step_samples:] = value.waveform.data[:num_step_samples]
292 |
293 | # Wrap waveform in a sliding window feature to include timestamps
294 | window = SlidingWindow(
295 | start=start_time, duration=resolution, step=resolution
296 | )
297 | waveform = SlidingWindowFeature(waveform, window)
298 |
299 | return OutputAccumulationState(annotation, waveform, real_time, new_next_sample)
300 |
301 | return rx.pipe(
302 | ops.scan(accumulate, OutputAccumulationState.initial()),
303 | ops.map(OutputAccumulationState.to_tuple),
304 | )
305 |
--------------------------------------------------------------------------------
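
A compact sketch of rearrange_audio_stream on a synthetic stream of half-second blocks, regrouping them into overlapping 5-second chunks emitted every 0.5 seconds. The RxPY construction details (rx.from_iterable, subscribe) are assumptions of this sketch rather than part of the module (illustrative only, not a file in this repository):

    # Sketch only: feed 0.5s waveform blocks, receive overlapping 5s chunks.
    import numpy as np
    import rx

    from diart.operators import rearrange_audio_stream

    sample_rate = 16000
    blocks = [np.random.randn(1, 8000).astype(np.float32) for _ in range(40)]  # 40 x 0.5s

    rx.from_iterable(blocks).pipe(
        rearrange_audio_stream(duration=5, step=0.5, sample_rate=sample_rate)
    ).subscribe(lambda chunk: print(chunk.data.shape))  # each chunk is (80000, 1)
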
/src/diart/optim.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from pathlib import Path
3 | from typing import Sequence, Text, Optional, Union
4 |
5 | from optuna import TrialPruned, Study, create_study
6 | from optuna.samplers import TPESampler
7 | from optuna.trial import Trial, FrozenTrial
8 | from pyannote.metrics.base import BaseMetric
9 | from tqdm import trange, tqdm
10 | from typing_extensions import Literal
11 |
12 | from . import blocks
13 | from .audio import FilePath
14 | from .inference import Benchmark
15 |
16 |
17 | class Optimizer:
18 | def __init__(
19 | self,
20 | pipeline_class: type,
21 | speech_path: Union[Text, Path],
22 | reference_path: Union[Text, Path],
23 | study_or_path: Union[FilePath, Study],
24 | batch_size: int = 32,
25 | hparams: Optional[Sequence[blocks.base.HyperParameter]] = None,
26 | base_config: Optional[blocks.PipelineConfig] = None,
27 | do_kickstart_hparams: bool = True,
28 | metric: Optional[BaseMetric] = None,
29 | direction: Literal["minimize", "maximize"] = "minimize",
30 | ):
31 | self.pipeline_class = pipeline_class
32 | # FIXME can we run this benchmark in parallel?
33 | # Currently it breaks the trial progress bar
34 | self.benchmark = Benchmark(
35 | speech_path,
36 | reference_path,
37 | show_progress=True,
38 | show_report=False,
39 | batch_size=batch_size,
40 | )
41 |
42 | self.metric = metric
43 | self.direction = direction
44 | self.base_config = base_config
45 | self.do_kickstart_hparams = do_kickstart_hparams
46 | if self.base_config is None:
47 | self.base_config = self.pipeline_class.get_config_class()()
48 | self.do_kickstart_hparams = False
49 |
50 | self.hparams = hparams
51 | if self.hparams is None:
52 | self.hparams = self.pipeline_class.hyper_parameters()
53 |
54 | # Make sure hyper-parameters exist in the configuration class given
55 | possible_hparams = vars(self.base_config)
56 | for param in self.hparams:
57 | msg = (
58 | f"Hyper-parameter {param.name} not found "
59 | f"in configuration {self.base_config.__class__.__name__}"
60 | )
61 | assert param.name in possible_hparams, msg
62 |
63 | self._progress: Optional[tqdm] = None
64 |
65 | if isinstance(study_or_path, Study):
66 | self.study = study_or_path
67 | elif isinstance(study_or_path, str) or isinstance(study_or_path, Path):
68 | study_or_path = Path(study_or_path)
69 | self.study = create_study(
70 | storage="sqlite:///" + str(study_or_path / f"{study_or_path.stem}.db"),
71 | sampler=TPESampler(),
72 | study_name=study_or_path.stem,
73 | direction=self.direction,
74 | load_if_exists=True,
75 | )
76 | else:
77 | msg = f"Expected Study object or path-like, but got {type(study_or_path).__name__}"
78 | raise ValueError(msg)
79 |
80 | @property
81 | def best_performance(self):
82 | return self.study.best_value
83 |
84 | @property
85 | def best_hparams(self):
86 | return self.study.best_params
87 |
88 | def _callback(self, study: Study, trial: FrozenTrial):
89 | if self._progress is None:
90 | return
91 | self._progress.update(1)
92 | self._progress.set_description(f"Trial {trial.number + 1}")
93 | values = {"best_perf": study.best_value}
94 | for name, value in study.best_params.items():
95 | values[f"best_{name}"] = value
96 | self._progress.set_postfix(OrderedDict(values))
97 |
98 | def objective(self, trial: Trial) -> float:
99 | # Set suggested values for optimized hyper-parameters
100 | trial_config = vars(self.base_config)
101 | for hparam in self.hparams:
102 | trial_config[hparam.name] = trial.suggest_uniform(
103 | hparam.name, hparam.low, hparam.high
104 | )
105 |
106 | # Prune trial if required
107 | if trial.should_prune():
108 | raise TrialPruned()
109 |
110 | # Instantiate the new configuration for the trial
111 | config = self.base_config.__class__(**trial_config)
112 |
113 | # Determine the evaluation metric
114 | metric = self.metric
115 | if metric is None:
116 | metric = self.pipeline_class.suggest_metric()
117 |
118 | # Run pipeline over the dataset
119 | report = self.benchmark(self.pipeline_class, config, metric)
120 |
121 | # Extract target metric from report
122 | return report.loc["TOTAL", metric.name]["%"]
123 |
124 | def __call__(self, num_iter: int, show_progress: bool = True):
125 | self._progress = None
126 | if show_progress:
127 | self._progress = trange(num_iter)
128 | last_trial = -1
129 | if self.study.trials:
130 | last_trial = self.study.trials[-1].number
131 | self._progress.set_description(f"Trial {last_trial + 1}")
132 | # Start with base config hyper-parameters if config was given
133 | if self.do_kickstart_hparams:
134 | self.study.enqueue_trial(
135 | {
136 | param.name: getattr(self.base_config, param.name)
137 | for param in self.hparams
138 | },
139 | skip_if_exists=True,
140 | )
141 | self.study.optimize(self.objective, num_iter, callbacks=[self._callback])
142 |
--------------------------------------------------------------------------------
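
Outside of the tuning script, the optimizer can also be used directly. A sketch with placeholder paths, assuming the top-level SpeakerDiarization export; the study directory must already exist because the SQLite database is created inside it (illustrative only, not a file in this repository):

    # Sketch only: optimize tau_active, rho_update and delta_new on a local dataset.
    from diart import SpeakerDiarization
    from diart.blocks.base import HyperParameter
    from diart.optim import Optimizer

    optimizer = Optimizer(
        pipeline_class=SpeakerDiarization,
        speech_path="/data/audio",              # audio files, e.g. CONVERSATION.wav
        reference_path="/data/rttm",            # matching CONVERSATION.rttm files
        study_or_path="/experiments/my_study",  # existing directory; my_study.db is created inside
        hparams=[HyperParameter.from_name(n) for n in ("tau_active", "rho_update", "delta_new")],
    )
    optimizer(num_iter=100, show_progress=True)
    print(optimizer.best_hparams, optimizer.best_performance)
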
/src/diart/progress.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Optional, Text
3 |
4 | import rich
5 | from rich.progress import Progress, TaskID
6 | from tqdm import tqdm
7 |
8 |
9 | class ProgressBar(ABC):
10 | @abstractmethod
11 | def create(
12 | self,
13 | total: int,
14 | description: Optional[Text] = None,
15 | unit: Text = "it",
16 | **kwargs,
17 | ):
18 | pass
19 |
20 | @abstractmethod
21 | def start(self):
22 | pass
23 |
24 | @abstractmethod
25 | def update(self, n: int = 1):
26 | pass
27 |
28 | @abstractmethod
29 | def write(self, text: Text):
30 | pass
31 |
32 | @abstractmethod
33 | def stop(self):
34 | pass
35 |
36 | @abstractmethod
37 | def close(self):
38 | pass
39 |
40 | @property
41 | @abstractmethod
42 | def default_description(self) -> Text:
43 | pass
44 |
45 | @property
46 | @abstractmethod
47 | def initial_description(self) -> Optional[Text]:
48 | pass
49 |
50 | def resolve_description(self, new_description: Optional[Text] = None) -> Text:
51 | if self.initial_description is None:
52 | if new_description is None:
53 | return self.default_description
54 | return new_description
55 | else:
56 | return self.initial_description
57 |
58 |
59 | class RichProgressBar(ProgressBar):
60 | def __init__(
61 | self,
62 | description: Optional[Text] = None,
63 | color: Text = "green",
64 | leave: bool = True,
65 | do_close: bool = True,
66 | ):
67 | self.description = description
68 | self.color = color
69 | self.do_close = do_close
70 | self.bar = Progress(transient=not leave)
71 | self.bar.start()
72 | self.task_id: Optional[TaskID] = None
73 |
74 | @property
75 | def default_description(self) -> Text:
76 | return f"[{self.color}]Streaming"
77 |
78 | @property
79 | def initial_description(self) -> Optional[Text]:
80 | if self.description is not None:
81 | return f"[{self.color}]{self.description}"
82 | return self.description
83 |
84 | def create(
85 | self,
86 | total: int,
87 | description: Optional[Text] = None,
88 | unit: Text = "it",
89 | **kwargs,
90 | ):
91 | if self.task_id is None:
92 | self.task_id = self.bar.add_task(
93 |                 self.resolve_description(None if description is None else f"[{self.color}]{description}"),
94 | start=False,
95 | total=total,
96 | completed=0,
97 | visible=True,
98 | **kwargs,
99 | )
100 |
101 | def start(self):
102 | assert self.task_id is not None
103 | self.bar.start_task(self.task_id)
104 |
105 | def update(self, n: int = 1):
106 | assert self.task_id is not None
107 | self.bar.update(self.task_id, advance=n)
108 |
109 | def write(self, text: Text):
110 | rich.print(text)
111 |
112 | def stop(self):
113 | assert self.task_id is not None
114 | self.bar.stop_task(self.task_id)
115 |
116 | def close(self):
117 | if self.do_close:
118 | self.bar.stop()
119 |
120 |
121 | class TQDMProgressBar(ProgressBar):
122 | def __init__(
123 | self,
124 | description: Optional[Text] = None,
125 | leave: bool = True,
126 | position: Optional[int] = None,
127 | do_close: bool = True,
128 | ):
129 | self.description = description
130 | self.leave = leave
131 | self.position = position
132 | self.do_close = do_close
133 | self.pbar: Optional[tqdm] = None
134 |
135 | @property
136 | def default_description(self) -> Text:
137 | return "Streaming"
138 |
139 | @property
140 | def initial_description(self) -> Optional[Text]:
141 | return self.description
142 |
143 | def create(
144 | self,
145 | total: int,
146 | description: Optional[Text] = None,
147 | unit: Optional[Text] = "it",
148 | **kwargs,
149 | ):
150 | if self.pbar is None:
151 | self.pbar = tqdm(
152 | desc=self.resolve_description(description),
153 | total=total,
154 | unit=unit,
155 | leave=self.leave,
156 | position=self.position,
157 | **kwargs,
158 | )
159 |
160 | def start(self):
161 | pass
162 |
163 | def update(self, n: int = 1):
164 | assert self.pbar is not None
165 | self.pbar.update(n)
166 |
167 | def write(self, text: Text):
168 | tqdm.write(text)
169 |
170 | def stop(self):
171 | self.close()
172 |
173 | def close(self):
174 | if self.do_close:
175 | assert self.pbar is not None
176 | self.pbar.close()
177 |
--------------------------------------------------------------------------------
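Both progress bars implement the same small lifecycle: create, start, update, write, stop, close. A minimal sketch using only the classes above (the total of 100 chunks is illustrative):

from diart.progress import RichProgressBar, TQDMProgressBar

# tqdm-backed bar: the underlying tqdm object is created lazily by create()
bar = TQDMProgressBar(description="Processing")
bar.create(total=100, unit="chunk")
bar.start()  # no-op for tqdm, but required by the ProgressBar interface
for i in range(100):
    bar.update()  # advance by one chunk
    if i == 49:
        bar.write("halfway there")  # print without breaking the bar
bar.close()

# rich-backed bar follows the same protocol; with no description it falls
# back to a colored "Streaming" label
rich_bar = RichProgressBar(color="cyan", leave=False)
rich_bar.create(total=10)
rich_bar.start()
rich_bar.update(10)
rich_bar.close()
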
/src/diart/sinks.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Union, Text, Optional, Tuple
3 |
4 | import matplotlib.pyplot as plt
5 | from pyannote.core import Annotation, Segment, SlidingWindowFeature, notebook
6 | from pyannote.database.util import load_rttm
7 | from pyannote.metrics.diarization import DiarizationErrorRate
8 | from rx.core import Observer
9 | from typing_extensions import Literal
10 |
11 |
12 | class WindowClosedException(Exception):
13 | pass
14 |
15 |
16 | def _extract_prediction(value: Union[Tuple, Annotation]) -> Annotation:
17 | if isinstance(value, tuple):
18 | return value[0]
19 | if isinstance(value, Annotation):
20 | return value
21 | msg = f"Expected tuple or Annotation, but got {type(value)}"
22 | raise ValueError(msg)
23 |
24 |
25 | class RTTMWriter(Observer):
26 | def __init__(self, uri: Text, path: Union[Path, Text], patch_collar: float = 0.05):
27 | super().__init__()
28 | self.uri = uri
29 | self.patch_collar = patch_collar
30 | self.path = Path(path).expanduser()
31 | if self.path.exists():
32 | self.path.unlink()
33 |
34 | def patch(self):
35 | """Stitch same-speaker turns that are close to each other"""
36 | if not self.path.exists():
37 | return
38 | annotations = list(load_rttm(self.path).values())
39 | if annotations:
40 | annotation = annotations[0]
41 | annotation.uri = self.uri
42 | with open(self.path, "w") as file:
43 | annotation.support(self.patch_collar).write_rttm(file)
44 |
45 | def on_next(self, value: Union[Tuple, Annotation]):
46 | prediction = _extract_prediction(value)
47 | # Write prediction in RTTM format
48 | prediction.uri = self.uri
49 | with open(self.path, "a") as file:
50 | prediction.write_rttm(file)
51 |
52 | def on_error(self, error: Exception):
53 | self.patch()
54 |
55 | def on_completed(self):
56 | self.patch()
57 |
58 |
59 | class PredictionAccumulator(Observer):
60 | def __init__(self, uri: Optional[Text] = None, patch_collar: float = 0.05):
61 | super().__init__()
62 | self.uri = uri
63 | self.patch_collar = patch_collar
64 | self._prediction: Optional[Annotation] = None
65 |
66 | def patch(self):
67 | """Stitch same-speaker turns that are close to each other"""
68 | if self._prediction is not None:
69 | self._prediction = self._prediction.support(self.patch_collar)
70 |
71 | def get_prediction(self) -> Annotation:
72 | # Patch again in case this is called before on_completed
73 | self.patch()
74 | return self._prediction
75 |
76 | def on_next(self, value: Union[Tuple, Annotation]):
77 | prediction = _extract_prediction(value)
78 | prediction.uri = self.uri
79 | if self._prediction is None:
80 | self._prediction = prediction
81 | else:
82 | self._prediction.update(prediction)
83 |
84 | def on_error(self, error: Exception):
85 | self.patch()
86 |
87 | def on_completed(self):
88 | self.patch()
89 |
90 |
91 | class StreamingPlot(Observer):
92 | def __init__(
93 | self,
94 | duration: float,
95 | latency: float,
96 | visualization: Literal["slide", "accumulate"] = "slide",
97 | reference: Optional[Union[Path, Text]] = None,
98 | ):
99 | super().__init__()
100 | assert visualization in ["slide", "accumulate"]
101 | self.visualization = visualization
102 | self.reference = reference
103 | if self.reference is not None:
104 | self.reference = list(load_rttm(reference).values())[0]
105 | self.window_duration = duration
106 | self.latency = latency
107 | self.figure, self.axs, self.num_axs = None, None, -1
108 |         # This flag lets us catch the matplotlib window-closed event and stop iterating on the next call
109 | self.window_closed = False
110 |
111 | def _on_window_closed(self, event):
112 | self.window_closed = True
113 |
114 | def _init_num_axs(self):
115 | if self.num_axs == -1:
116 | self.num_axs = 2
117 | if self.reference is not None:
118 | self.num_axs += 1
119 |
120 | def _init_figure(self):
121 | self._init_num_axs()
122 | self.figure, self.axs = plt.subplots(
123 | self.num_axs, 1, figsize=(10, 2 * self.num_axs)
124 | )
125 | if self.num_axs == 1:
126 | self.axs = [self.axs]
127 | self.figure.canvas.mpl_connect("close_event", self._on_window_closed)
128 |
129 | def _clear_axs(self):
130 | for i in range(self.num_axs):
131 | self.axs[i].clear()
132 |
133 | def get_plot_bounds(self, real_time: float) -> Segment:
134 | start_time = 0
135 | end_time = real_time - self.latency
136 | if self.visualization == "slide":
137 | start_time = max(0.0, end_time - self.window_duration)
138 | return Segment(start_time, end_time)
139 |
140 | def on_next(self, values: Tuple[Annotation, SlidingWindowFeature, float]):
141 | if self.window_closed:
142 | raise WindowClosedException
143 |
144 | prediction, waveform, real_time = values
145 |
146 | # Initialize figure if first call
147 | if self.figure is None:
148 | self._init_figure()
149 | # Clear previous plots
150 | self._clear_axs()
151 | # Set plot bounds
152 | notebook.crop = self.get_plot_bounds(real_time)
153 |
154 | # Align prediction and reference if possible
155 | if self.reference is not None:
156 | metric = DiarizationErrorRate()
157 | mapping = metric.optimal_mapping(self.reference, prediction)
158 | prediction.rename_labels(mapping=mapping, copy=False)
159 |
160 | # Plot prediction
161 | notebook.plot_annotation(prediction, self.axs[0])
162 | self.axs[0].set_title("Output")
163 |
164 | # Plot waveform
165 | notebook.plot_feature(waveform, self.axs[1])
166 | self.axs[1].set_title("Audio")
167 |
168 | # Plot reference if available
169 | if self.num_axs == 3:
170 | notebook.plot_annotation(self.reference, self.axs[2])
171 | self.axs[2].set_title("Reference")
172 |
173 | # Draw
174 | plt.tight_layout()
175 | self.figure.canvas.draw()
176 | self.figure.canvas.flush_events()
177 | plt.pause(0.05)
178 |
--------------------------------------------------------------------------------
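A minimal sketch of how these observers consume a stream of predictions. A plain rx Subject stands in for the pipeline's output stream, and the output path and segments are illustrative; in diart these sinks are normally attached to a streaming inference object defined elsewhere in the package.

from pyannote.core import Annotation, Segment
from rx.subject import Subject

from diart.sinks import PredictionAccumulator, RTTMWriter

predictions = Subject()
writer = RTTMWriter(uri="sample", path="/tmp/sample.rttm")
accumulator = PredictionAccumulator(uri="sample")
predictions.subscribe(writer)
predictions.subscribe(accumulator)

# Emit two incremental diarization outputs
first = Annotation(uri="sample")
first[Segment(0.0, 1.0)] = "speaker0"
second = Annotation(uri="sample")
second[Segment(1.0, 2.5)] = "speaker1"
predictions.on_next(first)
predictions.on_next(second)

# Completion triggers patch(): close same-speaker gaps below patch_collar
predictions.on_completed()
print(accumulator.get_prediction())
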
/src/diart/sources.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from pathlib import Path
3 | from queue import SimpleQueue
4 | from typing import Text, Optional, AnyStr, Dict, Any, Union, Tuple
5 |
6 | import numpy as np
7 | import sounddevice as sd
8 | import torch
9 | from einops import rearrange
10 | from rx.subject import Subject
11 | from torchaudio.io import StreamReader
12 | from websocket_server import WebsocketServer
13 |
14 | from . import utils
15 | from .audio import FilePath, AudioLoader
16 |
17 |
18 | class AudioSource(ABC):
19 | """Represents a source of audio that can start streaming via the `stream` property.
20 |
21 | Parameters
22 | ----------
23 | uri: Text
24 | Unique identifier of the audio source.
25 | sample_rate: int
26 | Sample rate of the audio source.
27 | """
28 |
29 | def __init__(self, uri: Text, sample_rate: int):
30 | self.uri = uri
31 | self.sample_rate = sample_rate
32 | self.stream = Subject()
33 |
34 | @property
35 | def duration(self) -> Optional[float]:
36 | """The duration of the stream if known. Defaults to None (unknown duration)."""
37 | return None
38 |
39 | @abstractmethod
40 | def read(self):
41 | """Start reading the source and yielding samples through the stream."""
42 | pass
43 |
44 | @abstractmethod
45 | def close(self):
46 | """Stop reading the source and close all open streams."""
47 | pass
48 |
49 |
50 | class FileAudioSource(AudioSource):
51 | """Represents an audio source tied to a file.
52 |
53 | Parameters
54 | ----------
55 | file: FilePath
56 | Path to the file to stream.
57 | sample_rate: int
58 | Sample rate of the chunks emitted.
59 | padding: (float, float)
60 | Left and right padding to add to the file (in seconds).
61 | Defaults to (0, 0).
62 |     block_duration: float
63 | Duration of each emitted chunk in seconds.
64 | Defaults to 0.5 seconds.
65 | """
66 |
67 | def __init__(
68 | self,
69 | file: FilePath,
70 | sample_rate: int,
71 | padding: Tuple[float, float] = (0, 0),
72 | block_duration: float = 0.5,
73 | ):
74 | super().__init__(Path(file).stem, sample_rate)
75 | self.loader = AudioLoader(self.sample_rate, mono=True)
76 | self._duration = self.loader.get_duration(file)
77 | self.file = file
78 | self.resolution = 1 / self.sample_rate
79 | self.block_size = int(np.rint(block_duration * self.sample_rate))
80 | self.padding_start, self.padding_end = padding
81 | self.is_closed = False
82 |
83 | @property
84 | def duration(self) -> Optional[float]:
85 | # The duration of a file is known
86 | return self.padding_start + self._duration + self.padding_end
87 |
88 | def read(self):
89 | """Send each chunk of samples through the stream"""
90 | waveform = self.loader.load(self.file)
91 |
92 | # Add zero padding at the beginning if required
93 | if self.padding_start > 0:
94 | num_pad_samples = int(np.rint(self.padding_start * self.sample_rate))
95 | zero_padding = torch.zeros(waveform.shape[0], num_pad_samples)
96 | waveform = torch.cat([zero_padding, waveform], dim=1)
97 |
98 | # Add zero padding at the end if required
99 | if self.padding_end > 0:
100 | num_pad_samples = int(np.rint(self.padding_end * self.sample_rate))
101 | zero_padding = torch.zeros(waveform.shape[0], num_pad_samples)
102 | waveform = torch.cat([waveform, zero_padding], dim=1)
103 |
104 | # Split into blocks
105 | _, num_samples = waveform.shape
106 | chunks = rearrange(
107 | waveform.unfold(1, self.block_size, self.block_size),
108 | "channel chunk sample -> chunk channel sample",
109 | ).numpy()
110 |
111 | # Add last incomplete chunk with padding
112 | if num_samples % self.block_size != 0:
113 | last_chunk = (
114 | waveform[:, chunks.shape[0] * self.block_size :].unsqueeze(0).numpy()
115 | )
116 | diff_samples = self.block_size - last_chunk.shape[-1]
117 | last_chunk = np.concatenate(
118 | [last_chunk, np.zeros((1, 1, diff_samples))], axis=-1
119 | )
120 | chunks = np.vstack([chunks, last_chunk])
121 |
122 | # Stream blocks
123 | for i, waveform in enumerate(chunks):
124 | try:
125 | if self.is_closed:
126 | break
127 | self.stream.on_next(waveform)
128 | except BaseException as e:
129 | self.stream.on_error(e)
130 | break
131 | self.stream.on_completed()
132 | self.close()
133 |
134 | def close(self):
135 | self.is_closed = True
136 |
137 |
138 | class MicrophoneAudioSource(AudioSource):
139 | """Audio source tied to a local microphone.
140 |
141 | Parameters
142 | ----------
143 |     block_duration: float
144 | Duration of each emitted chunk in seconds.
145 | Defaults to 0.5 seconds.
146 | device: int | str | (int, str) | None
147 |         Device identifier compatible with the sounddevice input stream.
148 | If None, use the default device.
149 | Defaults to None.
150 | """
151 |
152 | def __init__(
153 | self,
154 | block_duration: float = 0.5,
155 | device: Optional[Union[int, Text, Tuple[int, Text]]] = None,
156 | ):
157 | # Use the lowest supported sample rate
158 | sample_rates = [16000, 32000, 44100, 48000]
159 | best_sample_rate = None
160 | for sr in sample_rates:
161 | try:
162 | sd.check_input_settings(device=device, samplerate=sr)
163 | except Exception:
164 | pass
165 | else:
166 | best_sample_rate = sr
167 | break
168 | super().__init__(f"input_device:{device}", best_sample_rate)
169 |
170 | # Determine block size in samples and create input stream
171 | self.block_size = int(np.rint(block_duration * self.sample_rate))
172 | self._mic_stream = sd.InputStream(
173 | channels=1,
174 | samplerate=self.sample_rate,
175 | latency=0,
176 | blocksize=self.block_size,
177 | callback=self._read_callback,
178 | device=device,
179 | )
180 | self._queue = SimpleQueue()
181 |
182 | def _read_callback(self, samples, *args):
183 | self._queue.put_nowait(samples[:, [0]].T)
184 |
185 | def read(self):
186 | self._mic_stream.start()
187 | while self._mic_stream:
188 | try:
189 | while self._queue.empty():
190 | if self._mic_stream.closed:
191 | break
192 | self.stream.on_next(self._queue.get_nowait())
193 | except BaseException as e:
194 | self.stream.on_error(e)
195 | break
196 | self.stream.on_completed()
197 | self.close()
198 |
199 | def close(self):
200 | self._mic_stream.stop()
201 | self._mic_stream.close()
202 |
203 |
204 | class WebSocketAudioSource(AudioSource):
205 | """Represents a source of audio coming from the network using the WebSocket protocol.
206 |
207 | Parameters
208 | ----------
209 | sample_rate: int
210 | Sample rate of the chunks emitted.
211 | host: Text
212 |         The host on which to run the websocket server.
213 | Defaults to 127.0.0.1.
214 | port: int
215 |         The port on which to run the websocket server.
216 | Defaults to 7007.
217 | key: Text | Path | None
218 | Path to a key if using SSL.
219 | Defaults to no key.
220 | certificate: Text | Path | None
221 | Path to a certificate if using SSL.
222 | Defaults to no certificate.
223 | """
224 |
225 | def __init__(
226 | self,
227 | sample_rate: int,
228 | host: Text = "127.0.0.1",
229 | port: int = 7007,
230 | key: Optional[Union[Text, Path]] = None,
231 | certificate: Optional[Union[Text, Path]] = None,
232 | ):
233 | # FIXME sample_rate is not being used, this can be confusing and lead to incompatibilities.
234 | # I would prefer the client to send a JSON with data and sample rate, then resample if needed
235 | super().__init__(f"{host}:{port}", sample_rate)
236 | self.client: Optional[Dict[Text, Any]] = None
237 | self.server = WebsocketServer(host, port, key=key, cert=certificate)
238 | self.server.set_fn_message_received(self._on_message_received)
239 |
240 | def _on_message_received(
241 | self,
242 | client: Dict[Text, Any],
243 | server: WebsocketServer,
244 | message: AnyStr,
245 | ):
246 | # Only one client at a time is allowed
247 | if self.client is None or self.client["id"] != client["id"]:
248 | self.client = client
249 | # Send decoded audio to pipeline
250 | self.stream.on_next(utils.decode_audio(message))
251 |
252 | def read(self):
253 |         """Start the websocket server and listen for audio chunks"""
254 | self.server.run_forever()
255 |
256 | def close(self):
257 | """Close the websocket server"""
258 | if self.server is not None:
259 | self.stream.on_completed()
260 | self.server.shutdown_gracefully()
261 |
262 | def send(self, message: AnyStr):
263 | """Send a message through the current websocket.
264 |
265 | Parameters
266 | ----------
267 | message: AnyStr
268 | Bytes or string to send.
269 | """
270 | if len(message) > 0:
271 | self.server.send_message(self.client, message)
272 |
273 |
274 | class TorchStreamAudioSource(AudioSource):
275 | def __init__(
276 | self,
277 | uri: Text,
278 | sample_rate: int,
279 | streamer: StreamReader,
280 | stream_index: Optional[int] = None,
281 | block_duration: float = 0.5,
282 | ):
283 | super().__init__(uri, sample_rate)
284 | self.block_size = int(np.rint(block_duration * self.sample_rate))
285 | self._streamer = streamer
286 | self._streamer.add_basic_audio_stream(
287 | frames_per_chunk=self.block_size,
288 | stream_index=stream_index,
289 | format="fltp",
290 | sample_rate=self.sample_rate,
291 | )
292 | self.is_closed = False
293 |
294 | def read(self):
295 | for item in self._streamer.stream():
296 | try:
297 | if self.is_closed:
298 | break
299 | # shape (samples, channels) to (1, samples)
300 | chunk = np.mean(item[0].numpy(), axis=1, keepdims=True).T
301 | self.stream.on_next(chunk)
302 | except BaseException as e:
303 | self.stream.on_error(e)
304 | break
305 | self.stream.on_completed()
306 | self.close()
307 |
308 | def close(self):
309 | self.is_closed = True
310 |
311 |
312 | class AppleDeviceAudioSource(TorchStreamAudioSource):
313 | def __init__(
314 | self,
315 | sample_rate: int,
316 | device: str = "0:0",
317 | stream_index: int = 0,
318 | block_duration: float = 0.5,
319 | ):
320 | uri = f"apple_input_device:{device}"
321 | streamer = StreamReader(device, format="avfoundation")
322 | super().__init__(uri, sample_rate, streamer, stream_index, block_duration)
323 |
--------------------------------------------------------------------------------
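A minimal sketch of streaming a file with FileAudioSource. The path and sample rate are illustrative; subscribers receive numpy chunks of shape (channel, sample), and read() blocks until the file is exhausted.

from diart.sources import FileAudioSource

source = FileAudioSource("audio/sample.wav", sample_rate=16000, block_duration=0.5)
source.stream.subscribe(
    on_next=lambda chunk: print(chunk.shape),        # e.g. (1, 8000) per 0.5 s block
    on_error=lambda e: print(f"stream error: {e}"),
    on_completed=lambda: print("done"),
)
source.read()  # emits every block through source.stream, then completes
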
/src/diart/utils.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import time
3 | from typing import Optional, Text, Union
4 |
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | from pyannote.core import Annotation, Segment, SlidingWindowFeature, notebook
8 |
9 | from . import blocks
10 | from .progress import ProgressBar
11 |
12 |
13 | class Chronometer:
14 | def __init__(self, unit: Text, progress_bar: Optional[ProgressBar] = None):
15 | self.unit = unit
16 | self.progress_bar = progress_bar
17 | self.current_start_time = None
18 | self.history = []
19 |
20 | @property
21 | def is_running(self):
22 | return self.current_start_time is not None
23 |
24 | def start(self):
25 | self.current_start_time = time.monotonic()
26 |
27 | def stop(self, do_count: bool = True):
28 |         msg = "No start time available. Did you call stop() before start()?"
29 |         assert self.current_start_time is not None, msg
30 |         elapsed = time.monotonic() - self.current_start_time
31 |         self.current_start_time = None
32 |         if do_count:
33 |             self.history.append(elapsed)
34 |
35 | def report(self):
36 | print_fn = print
37 | if self.progress_bar is not None:
38 | print_fn = self.progress_bar.write
39 | print_fn(
40 | f"Took {np.mean(self.history).item():.3f} "
41 | f"(+/-{np.std(self.history).item():.3f}) seconds/{self.unit} "
42 | f"-- ran {len(self.history)} times"
43 | )
44 |
45 |
46 | def parse_hf_token_arg(hf_token: Union[bool, Text]) -> Union[bool, Text]:
47 | if isinstance(hf_token, bool):
48 | return hf_token
49 | if hf_token.lower() == "true":
50 | return True
51 | if hf_token.lower() == "false":
52 | return False
53 | return hf_token
54 |
55 |
56 | def encode_audio(waveform: np.ndarray) -> Text:
57 | data = waveform.astype(np.float32).tobytes()
58 | return base64.b64encode(data).decode("utf-8")
59 |
60 |
61 | def decode_audio(data: Text) -> np.ndarray:
62 | # Decode chunk encoded in base64
63 | byte_samples = base64.decodebytes(data.encode("utf-8"))
64 | # Recover array from bytes
65 | samples = np.frombuffer(byte_samples, dtype=np.float32)
66 | return samples.reshape(1, -1)
67 |
68 |
69 | def get_padding_left(stream_duration: float, chunk_duration: float) -> float:
70 | if stream_duration < chunk_duration:
71 | return chunk_duration - stream_duration
72 | return 0
73 |
74 |
75 | def repeat_label(label: Text):
76 | while True:
77 | yield label
78 |
79 |
80 | def get_pipeline_class(class_name: Text) -> type:
81 | pipeline_class = getattr(blocks, class_name, None)
82 | msg = f"Pipeline '{class_name}' doesn't exist"
83 | assert pipeline_class is not None, msg
84 | return pipeline_class
85 |
86 |
87 | def get_padding_right(latency: float, step: float) -> float:
88 | return latency - step
89 |
90 |
91 | def visualize_feature(duration: Optional[float] = None):
92 | def apply(feature: SlidingWindowFeature):
93 | if duration is None:
94 | notebook.crop = feature.extent
95 | else:
96 | notebook.crop = Segment(feature.extent.end - duration, feature.extent.end)
97 | plt.rcParams["figure.figsize"] = (8, 2)
98 | notebook.plot_feature(feature)
99 | plt.tight_layout()
100 | plt.show()
101 |
102 | return apply
103 |
104 |
105 | def visualize_annotation(duration: Optional[float] = None):
106 | def apply(annotation: Annotation):
107 | extent = annotation.get_timeline().extent()
108 | if duration is None:
109 | notebook.crop = extent
110 | else:
111 | notebook.crop = Segment(extent.end - duration, extent.end)
112 | plt.rcParams["figure.figsize"] = (8, 2)
113 | notebook.plot_annotation(annotation)
114 | plt.tight_layout()
115 | plt.show()
116 |
117 | return apply
118 |
--------------------------------------------------------------------------------
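Two small sketches with the helpers above: a round trip through the base64 audio codec used by the websocket source, and a Chronometer timing a stand-in workload. Shapes and the workload are illustrative.

import numpy as np

from diart.utils import Chronometer, decode_audio, encode_audio

# Base64 round trip: float32 chunk -> text payload -> float32 chunk of shape (1, n)
chunk = np.random.randn(1, 8000).astype(np.float32)
payload = encode_audio(chunk)
recovered = decode_audio(payload)
assert np.array_equal(chunk, recovered)

# Time a few iterations and report mean/std seconds per unit
timer = Chronometer(unit="chunk")
for _ in range(3):
    timer.start()
    np.fft.rfft(np.random.randn(16000))  # stand-in for real processing
    timer.stop()
timer.report()  # e.g. "Took 0.001 (+/-0.000) seconds/chunk -- ran 3 times"
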
/table1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juanmc2005/diart/392d53a1b0cd67701ecc20b683bb10614df2f7fc/table1.png
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import pytest
4 | import torch
5 |
6 | from diart.models import SegmentationModel, EmbeddingModel
7 |
8 |
9 | class DummySegmentationModel:
10 | def to(self, device):
11 | pass
12 |
13 | def __call__(self, waveform: torch.Tensor) -> torch.Tensor:
14 | assert waveform.ndim == 3
15 |
16 | batch_size, num_channels, num_samples = waveform.shape
17 | num_frames = random.randint(250, 500)
18 | num_speakers = random.randint(3, 5)
19 |
20 | return torch.rand(batch_size, num_frames, num_speakers)
21 |
22 |
23 | class DummyEmbeddingModel:
24 | def to(self, device):
25 | pass
26 |
27 | def __call__(self, waveform: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
28 | assert waveform.ndim == 3
29 | assert weights.ndim == 2
30 |
31 | batch_size, num_channels, num_samples = waveform.shape
32 | batch_size_weights, num_frames = weights.shape
33 |
34 | assert batch_size == batch_size_weights
35 |
36 | embedding_dim = random.randint(128, 512)
37 |
38 | return torch.randn(batch_size, embedding_dim)
39 |
40 |
41 | @pytest.fixture(scope="session")
42 | def segmentation_model() -> SegmentationModel:
43 | return SegmentationModel(DummySegmentationModel)
44 |
45 |
46 | @pytest.fixture(scope="session")
47 | def embedding_model() -> EmbeddingModel:
48 | return EmbeddingModel(DummyEmbeddingModel)
49 |
--------------------------------------------------------------------------------
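The dummy models mimic the interfaces of the real segmentation and embedding models with random outputs, so the test suite can run without downloading any weights. A small sketch of the expected shapes, assuming the classes above are in scope (tensor sizes are illustrative):

import torch

waveform = torch.randn(4, 1, 5 * 16000)         # (batch, channel, sample)

scores = DummySegmentationModel()(waveform)     # (batch, num_frames, num_speakers)
weights = torch.rand(4, scores.shape[1])        # per-frame speaker weights
emb = DummyEmbeddingModel()(waveform, weights)  # (batch, embedding_dim)
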
/tests/data/audio/sample.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/juanmc2005/diart/392d53a1b0cd67701ecc20b683bb10614df2f7fc/tests/data/audio/sample.wav
--------------------------------------------------------------------------------
/tests/data/rttm/latency_0.5.rttm:
--------------------------------------------------------------------------------
1 | SPEAKER sample 1 6.675 0.533 <NA> <NA> speaker0 <NA> <NA>
2 | SPEAKER sample 1 7.625 1.883 <NA> <NA> speaker0 <NA> <NA>
3 | SPEAKER sample 1 9.508 1.000 <NA> <NA> speaker1 <NA> <NA>
4 | SPEAKER sample 1 10.508 0.567 <NA> <NA> speaker0 <NA> <NA>
5 | SPEAKER sample 1 10.625 4.133 <NA> <NA> speaker1 <NA> <NA>
6 | SPEAKER sample 1 14.325 3.733 <NA> <NA> speaker0 <NA> <NA>
7 | SPEAKER sample 1 18.058 3.450 <NA> <NA> speaker1 <NA> <NA>
8 | SPEAKER sample 1 18.325 0.183 <NA> <NA> speaker0 <NA> <NA>
9 | SPEAKER sample 1 21.508 0.017 <NA> <NA> speaker0 <NA> <NA>
10 | SPEAKER sample 1 21.775 0.233 <NA> <NA> speaker1 <NA> <NA>
11 | SPEAKER sample 1 22.008 6.633 <NA> <NA> speaker0 <NA> <NA>
12 | SPEAKER sample 1 28.508 1.500 <NA> <NA> speaker1 <NA> <NA>
13 | SPEAKER sample 1 29.958 0.050 <NA> <NA> speaker0 <NA> <NA>
14 |
--------------------------------------------------------------------------------
/tests/data/rttm/latency_1.rttm:
--------------------------------------------------------------------------------
1 | SPEAKER sample 1 6.708 0.450 <NA> <NA> speaker0 <NA> <NA>
2 | SPEAKER sample 1 7.625 1.383 <NA> <NA> speaker0 <NA> <NA>
3 | SPEAKER sample 1 9.008 1.500 <NA> <NA> speaker1 <NA> <NA>
4 | SPEAKER sample 1 10.008 1.067 <NA> <NA> speaker0 <NA> <NA>
5 | SPEAKER sample 1 10.592 4.200 <NA> <NA> speaker1 <NA> <NA>
6 | SPEAKER sample 1 14.308 3.700 <NA> <NA> speaker0 <NA> <NA>
7 | SPEAKER sample 1 18.042 3.250 <NA> <NA> speaker1 <NA> <NA>
8 | SPEAKER sample 1 18.508 0.033 <NA> <NA> speaker0 <NA> <NA>
9 | SPEAKER sample 1 21.108 0.383 <NA> <NA> speaker0 <NA> <NA>
10 | SPEAKER sample 1 21.508 0.033 <NA> <NA> speaker1 <NA> <NA>
11 | SPEAKER sample 1 21.775 6.817 <NA> <NA> speaker0 <NA> <NA>
12 | SPEAKER sample 1 28.008 2.000 <NA> <NA> speaker1 <NA> <NA>
13 | SPEAKER sample 1 29.975 0.033 <NA> <NA> speaker0 <NA> <NA>
14 |
--------------------------------------------------------------------------------
/tests/data/rttm/latency_2.rttm:
--------------------------------------------------------------------------------
1 | SPEAKER sample 1 6.725 0.433 <NA> <NA> speaker0 <NA> <NA>
2 | SPEAKER sample 1 7.592 0.817 <NA> <NA> speaker0 <NA> <NA>
3 | SPEAKER sample 1 8.475 1.617 <NA> <NA> speaker1 <NA> <NA>
4 | SPEAKER sample 1 9.892 1.150 <NA> <NA> speaker0 <NA> <NA>
5 | SPEAKER sample 1 10.625 4.133 <NA> <NA> speaker1 <NA> <NA>
6 | SPEAKER sample 1 14.292 3.667 <NA> <NA> speaker0 <NA> <NA>
7 | SPEAKER sample 1 18.008 3.533 <NA> <NA> speaker1 <NA> <NA>
8 | SPEAKER sample 1 18.225 0.283 <NA> <NA> speaker0 <NA> <NA>
9 | SPEAKER sample 1 21.758 6.867 <NA> <NA> speaker0 <NA> <NA>
10 | SPEAKER sample 1 27.875 2.133 <NA> <NA> speaker1 <NA> <NA>
11 |
--------------------------------------------------------------------------------
/tests/data/rttm/latency_3.rttm:
--------------------------------------------------------------------------------
1 | SPEAKER sample 1 6.725 0.433 <NA> <NA> speaker0 <NA> <NA>
2 | SPEAKER sample 1 7.625 0.467 <NA> <NA> speaker0 <NA> <NA>
3 | SPEAKER sample 1 8.008 2.050 <NA> <NA> speaker1 <NA> <NA>
4 | SPEAKER sample 1 9.875 1.167 <NA> <NA> speaker0 <NA> <NA>
5 | SPEAKER sample 1 10.592 4.167 <NA> <NA> speaker1 <NA> <NA>
6 | SPEAKER sample 1 14.292 3.667 <NA> <NA> speaker0 <NA> <NA>
7 | SPEAKER sample 1 17.992 3.550 <NA> <NA> speaker1 <NA> <NA>
8 | SPEAKER sample 1 18.192 0.367 <NA> <NA> speaker0 <NA> <NA>
9 | SPEAKER sample 1 21.758 6.833 <NA> <NA> speaker0 <NA> <NA>
10 | SPEAKER sample 1 27.825 2.183 <NA> <NA> speaker1 <NA> <NA>
11 |
--------------------------------------------------------------------------------
/tests/data/rttm/latency_4.rttm:
--------------------------------------------------------------------------------
1 | SPEAKER sample 1 6.742 0.400 <NA> <NA> speaker0 <NA> <NA>
2 | SPEAKER sample 1 7.625 0.650 <NA> <NA> speaker0 <NA> <NA>
3 | SPEAKER sample 1 8.092 1.950 <NA> <NA> speaker1 <NA> <NA>
4 | SPEAKER sample 1 9.875 1.167 <NA> <NA> speaker0 <NA> <NA>
5 | SPEAKER sample 1 10.575 4.183 <NA> <NA> speaker1 <NA> <NA>
6 | SPEAKER sample 1 14.308 3.667 <NA> <NA> speaker0 <NA> <NA>
7 | SPEAKER sample 1 17.992 3.550 <NA> <NA> speaker1 <NA> <NA>
8 | SPEAKER sample 1 18.208 0.333 <NA> <NA> speaker0 <NA> <NA>
9 | SPEAKER sample 1 21.758 6.817 <NA> <NA> speaker0 <NA> <NA>
10 | SPEAKER sample 1 27.808 2.200 <NA> <NA> speaker1 <NA> <NA>
11 |
--------------------------------------------------------------------------------
/tests/data/rttm/latency_5.rttm:
--------------------------------------------------------------------------------
1 | SPEAKER sample 1 6.742 0.383 <NA> <NA> speaker0 <NA> <NA>
2 | SPEAKER sample 1 7.625 0.667 <NA> <NA> speaker0 <NA> <NA>
3 | SPEAKER sample 1 8.092 1.967 <NA> <NA> speaker1 <NA> <NA>
4 | SPEAKER sample 1 9.875 1.167 <NA> <NA> speaker0 <NA> <NA>
5 | SPEAKER sample 1 10.558 4.200 <NA> <NA> speaker1 <NA> <NA>
6 | SPEAKER sample 1 14.308 3.667 <NA> <NA> speaker0 <NA> <NA>