├── diarizen
├── __init__.py
├── models
│ ├── __init__.py
│ ├── module
│ │ └── wav2vec2
│ │ │ ├── __init__.py
│ │ │ ├── utils
│ │ │ │ └── __init__.py
│ │ │ ├── .DS_Store
│ │ │ └── pruning_utils.py
│ └── pruning
│ │ └── model_distill_prune.py
├── pipelines
│ ├── __init__.py
│ └── utils.py
├── optimization.py
├── trainer_utils.py
├── noam_updater.py
├── logger.py
└── ckpt_utils.py
├── pyannote-audio
├── version.txt
├── .github
│ ├── FUNDING.yml
│ ├── ISSUE_TEMPLATE
│ │ ├── config.yml
│ │ └── bug_report.yml
│ ├── workflows
│ │ ├── pypi.yml
│ │ ├── test.yml
│ │ ├── test_cli.yml
│ │ └── doc.yml
│ └── stale.yml
├── codecov.yml
├── doc
│ ├── requirements.txt
│ ├── source
│ │ └── index.rst
│ └── gen_docs.py
├── tutorials
│ ├── assets
│ │ ├── sample.wav
│ │ ├── download-model.png
│ │ ├── pyannote.diff.PNG
│ │ ├── download-pipeline.png
│ │ ├── pyannote.review.PNG
│ │ ├── prodigy-pyannote.audio.png
│ │ └── sample.rttm
│ └── speaker_verification.ipynb
├── pyannote
│ ├── audio
│ │ ├── cli
│ │ │ ├── train_config
│ │ │ │ ├── model
│ │ │ │ │ ├── XVectorMFCC.yaml
│ │ │ │ │ ├── Pretrained.yaml
│ │ │ │ │ ├── XVectorSincNet.yaml
│ │ │ │ │ ├── DebugEmbedding.yaml
│ │ │ │ │ ├── DebugSegmentation.yaml
│ │ │ │ │ ├── PyanNet.yaml
│ │ │ │ │ └── SSeRiouSS.yaml
│ │ │ │ ├── trainer
│ │ │ │ │ ├── fast_dev_run.yaml
│ │ │ │ │ └── default.yaml
│ │ │ │ ├── optimizer
│ │ │ │ │ ├── Adan.yaml
│ │ │ │ │ ├── Adam.yaml
│ │ │ │ │ └── AdamW.yaml
│ │ │ │ ├── preprocessor
│ │ │ │ │ └── LowerTemporalResolution.yaml
│ │ │ │ ├── scheduler
│ │ │ │ │ ├── CyclicLR.yaml
│ │ │ │ │ ├── CosineAnnealingWarmRestarts.yaml
│ │ │ │ │ └── ReduceLROnPlateau.yaml
│ │ │ │ ├── config.yaml
│ │ │ │ ├── task
│ │ │ │ │ ├── SpeakerDiarization.yaml
│ │ │ │ │ ├── MultiLabelSegmentation.yaml
│ │ │ │ │ ├── VoiceActivityDetection.yaml
│ │ │ │ │ ├── SpeakerEmbedding.yaml
│ │ │ │ │ └── OverlappedSpeechDetection.yaml
│ │ │ │ ├── hydra
│ │ │ │ │ └── default.yaml
│ │ │ │ └── __init__.py
│ │ │ ├── evaluate_config
│ │ │ │ ├── config.yaml
│ │ │ │ ├── hydra
│ │ │ │ │ └── default.yaml
│ │ │ │ └── __init__.py
│ │ │ ├── config
│ │ │ │ └── hydra
│ │ │ │ │ └── default.yaml
│ │ │ ├── __init__.py
│ │ │ ├── pretrained.py
│ │ │ ├── lr_schedulers
│ │ │ │ ├── __init__.py
│ │ │ │ ├── CosineAnnealingWarmRestarts.py
│ │ │ │ ├── CyclicLR.py
│ │ │ │ └── ReduceLROnPlateau.py
│ │ │ └── evaluate.py
│ │ ├── sample
│ │ │ ├── sample.wav
│ │ │ ├── sample.rttm
│ │ │ └── __init__.py
│ │ ├── utils
│ │ │ ├── params.py
│ │ │ ├── __init__.py
│ │ │ ├── version.py
│ │ │ ├── random.py
│ │ │ ├── multi_task.py
│ │ │ └── reproducibility.py
│ │ ├── core
│ │ │ └── __init__.py
│ │ ├── models
│ │ │ ├── __init__.py
│ │ │ ├── segmentation
│ │ │ │ └── __init__.py
│ │ │ └── embedding
│ │ │ │ ├── wespeaker
│ │ │ │ │ ├── LICENSE.WeSpeaker
│ │ │ │ │ └── convert.py
│ │ │ │ └── __init__.py
│ │ ├── tasks
│ │ │ ├── embedding
│ │ │ │ └── __init__.py
│ │ │ ├── segmentation
│ │ │ │ └── __init__.py
│ │ │ └── __init__.py
│ │ ├── torchmetrics
│ │ │ ├── functional
│ │ │ │ ├── __init__.py
│ │ │ │ └── audio
│ │ │ │ │ └── __init__.py
│ │ │ ├── classification
│ │ │ │ ├── __init__.py
│ │ │ │ └── equal_error_rate.py
│ │ │ ├── __init__.py
│ │ │ └── audio
│ │ │ │ └── __init__.py
│ │ ├── augmentation
│ │ │ └── __init__.py
│ │ ├── __init__.py
│ │ └── pipelines
│ │ │ ├── __init__.py
│ │ │ └── utils
│ │ │ │ └── __init__.py
│ └── __init__.py
├── questions
│ ├── README.md
│ ├── from_memory.question.md
│ ├── pyannote.question.md
│ ├── streaming.question.md
│ ├── bad_performance.question.md
│ └── offline.question.md
├── .gitattributes
├── .gitmodules
├── MANIFEST.in
├── environment.yaml
├── tests
│ ├── test_import_lib.py
│ ├── test_run_notebooks.py
│ ├── utils
│ │ ├── preview.py
│ │ ├── probe_util_test.py
│ │ ├── test_permutation.py
│ │ └── test_powerset.py
│ ├── tasks
│ │ ├── test_specifications.py
│ │ └── test_reproducibility.py
│ ├── test_clustering.py
│ ├── test_sample.py
│ ├── conftest.py
│ ├── io_test.py
│ └── inference_test.py
├── faq.yml
├── .faq
│ ├── FAQ.md
│ └── suggest.md
├── requirements.txt
├── LICENSE
├── .pre-commit-config.yaml
├── .gitignore
├── setup.py
├── notebook
│ ├── sharing.ipynb
│ ├── freeze.ipynb
│ └── augmentation.ipynb
├── setup.cfg
└── FAQ.md
├── recipes
├── diar_ssl_pruning
│ ├── data
│ ├── convert_wavlm_from_hf.py
│ ├── conf
│ │ ├── dual_opt_common.toml
│ │ ├── dual_opt_common_large.toml
│ │ ├── s80_base.toml
│ │ └── s80_large.toml
│ └── README.md
└── diar_ssl
│ ├── data
│ │ ├── AMI_AliMeeting_AISHELL4
│ │ │ ├── test
│ │ │ │ ├── AMI
│ │ │ │ │ ├── all.uem
│ │ │ │ │ └── wav.scp
│ │ │ │ ├── AISHELL4
│ │ │ │ │ ├── all.uem
│ │ │ │ │ └── wav.scp
│ │ │ │ └── AliMeeting
│ │ │ │ │ ├── all.uem
│ │ │ │ │ └── wav.scp
│ │ │ └── dev
│ │ │ │ ├── all.uem
│ │ │ │ └── wav.scp
│ │ └── NSF
│ │ │ └── dev_sc
│ │ │ │ └── all.uem
│ ├── README.md
│ ├── conf
│ │ ├── pyannote_baseline.toml
│ │ ├── wavlm_frozen_conformer.toml
│ │ ├── fbank_conformer.toml
│ │ └── wavlm_updated_conformer.toml
│ └── run_stage.sh
├── .gitmodules
├── example
└── EN2002a_30s.wav
├── requirements.txt
├── LICENSE
├── MODEL_LICENSE
├── .gitignore
└── pyproject.toml
/diarizen/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/diarizen/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/diarizen/pipelines/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/pyannote-audio/version.txt:
--------------------------------------------------------------------------------
1 | 3.1.1
2 |
--------------------------------------------------------------------------------
/diarizen/models/module/wav2vec2/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/recipes/diar_ssl_pruning/data:
--------------------------------------------------------------------------------
1 | ../diar_ssl/data/
--------------------------------------------------------------------------------
/diarizen/models/module/wav2vec2/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "dscore"]
2 | path = dscore
3 | url = https://github.com/nryant/dscore.git
4 |
--------------------------------------------------------------------------------
/example/EN2002a_30s.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BUTSpeechFIT/DiariZen/HEAD/example/EN2002a_30s.wav
--------------------------------------------------------------------------------
/pyannote-audio/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | # These are supported funding model platforms
2 |
3 | github: [hbredin]
4 |
--------------------------------------------------------------------------------
/pyannote-audio/codecov.yml:
--------------------------------------------------------------------------------
1 | coverage:
2 | status:
3 | patch:
4 | default:
5 | enabled: false
6 |
--------------------------------------------------------------------------------
/pyannote-audio/doc/requirements.txt:
--------------------------------------------------------------------------------
1 | ipython==8.10.0
2 | recommonmark
3 | Sphinx==3.0.4
4 | sphinx_rtd_theme==0.4.3
5 |
--------------------------------------------------------------------------------
/diarizen/models/module/wav2vec2/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BUTSpeechFIT/DiariZen/HEAD/diarizen/models/module/wav2vec2/.DS_Store
--------------------------------------------------------------------------------
/pyannote-audio/tutorials/assets/sample.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BUTSpeechFIT/DiariZen/HEAD/pyannote-audio/tutorials/assets/sample.wav
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/cli/train_config/model/XVectorMFCC.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | _target_: pyannote.audio.models.embedding.XVectorMFCC
3 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/sample/sample.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BUTSpeechFIT/DiariZen/HEAD/pyannote-audio/pyannote/audio/sample/sample.wav
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/cli/train_config/model/Pretrained.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | _target_: pyannote.audio.cli.pretrained
3 | checkpoint: ???
4 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/cli/train_config/model/XVectorSincNet.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | _target_: pyannote.audio.models.embedding.XVectorSincNet
3 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/cli/train_config/trainer/fast_dev_run.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | _target_: pytorch_lightning.Trainer
3 | fast_dev_run: True
4 |
--------------------------------------------------------------------------------
/pyannote-audio/tutorials/assets/download-model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BUTSpeechFIT/DiariZen/HEAD/pyannote-audio/tutorials/assets/download-model.png
--------------------------------------------------------------------------------
/pyannote-audio/tutorials/assets/pyannote.diff.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BUTSpeechFIT/DiariZen/HEAD/pyannote-audio/tutorials/assets/pyannote.diff.PNG
--------------------------------------------------------------------------------
/pyannote-audio/tutorials/assets/download-pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BUTSpeechFIT/DiariZen/HEAD/pyannote-audio/tutorials/assets/download-pipeline.png
--------------------------------------------------------------------------------
/pyannote-audio/tutorials/assets/pyannote.review.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BUTSpeechFIT/DiariZen/HEAD/pyannote-audio/tutorials/assets/pyannote.review.PNG
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/cli/train_config/model/DebugEmbedding.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | _target_: pyannote.audio.models.embedding.debug.SimpleEmbeddingModel
3 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/cli/evaluate_config/config.yaml:
--------------------------------------------------------------------------------
1 | model: ???
2 | protocol: ???
3 | warm_up: 0.0
4 | subset: test
5 |
6 | defaults:
7 | - hydra: default
8 |
--------------------------------------------------------------------------------
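
Note: in these Hydra configs, `???` is OmegaConf's mandatory-value marker: the key exists but raises an error on access until it is overridden (typically on the command line, e.g. `model=... protocol=...`). A minimal sketch of the mechanism:

```python
# Minimal sketch of OmegaConf's "???" (mandatory value) behavior.
from omegaconf import OmegaConf
from omegaconf.errors import MissingMandatoryValue

cfg = OmegaConf.create({"model": "???", "protocol": "???", "warm_up": 0.0})

try:
    _ = cfg.model  # never overridden, so access fails
except MissingMandatoryValue:
    print("`model` is mandatory: supply it as an override, e.g. model=/path/to.ckpt")
```
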
/pyannote-audio/tutorials/assets/prodigy-pyannote.audio.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BUTSpeechFIT/DiariZen/HEAD/pyannote-audio/tutorials/assets/prodigy-pyannote.audio.png
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/cli/train_config/model/DebugSegmentation.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | _target_: pyannote.audio.models.segmentation.debug.SimpleSegmentationModel
3 |
--------------------------------------------------------------------------------
/pyannote-audio/questions/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Questions
3 |
4 | Your questions should go in this directory.
5 |
6 | Question files should be named with the extension ".question.md".
7 |
--------------------------------------------------------------------------------
/pyannote-audio/.gitattributes:
--------------------------------------------------------------------------------
1 | pyannote/audio/_version.py export-subst
2 | notebooks/* linguist-documentation
3 | tutorials/* linguist-documentation
4 | versioneer.py linguist-vendored
5 |
--------------------------------------------------------------------------------
/pyannote-audio/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "tutorials/AMI-diarization-setup"]
2 | path = tutorials/AMI-diarization-setup
3 | url = https://github.com/pyannote/AMI-diarization-setup.git
4 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/cli/train_config/optimizer/Adan.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | _target_: adan_pytorch.Adan
3 | lr: 1e-3
4 | betas: [0.1, 0.1, 0.001]
5 | weight_decay: 0.0
6 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/cli/train_config/preprocessor/LowerTemporalResolution.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | _target_: pyannote.audio.utils.preprocessors.LowerTemporalResolution
3 | resolution: 0.1
4 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/cli/train_config/optimizer/Adam.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | _target_: torch.optim.Adam
3 | lr: 1e-3
4 | betas: [0.9, 0.999]
5 | eps: 1e-08
6 | weight_decay: 0
7 | amsgrad: False
8 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/cli/train_config/optimizer/AdamW.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | _target_: torch.optim.AdamW
3 | lr: 1e-3
4 | betas: [0.9, 0.999]
5 | eps: 1e-08
6 | weight_decay: 0.01
7 | amsgrad: False
8 |
--------------------------------------------------------------------------------
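
Note: the `_target_` key in these optimizer configs is Hydra's instantiation hook. As a rough illustration of how such a config becomes a `torch.optim` object (a minimal sketch, not the exact `pyannote-audio-train` code path):

```python
import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf

model = torch.nn.Linear(10, 2)  # stand-in model for the illustration

cfg = OmegaConf.create(
    {
        "_target_": "torch.optim.AdamW",
        "lr": 1e-3,
        "betas": [0.9, 0.999],
        "eps": 1e-08,
        "weight_decay": 0.01,
        "amsgrad": False,
    }
)

# Positional arguments are forwarded to the target constructor;
# _convert_="all" turns OmegaConf containers into plain Python ones.
optimizer = instantiate(cfg, model.parameters(), _convert_="all")
print(type(optimizer).__name__)  # AdamW
```
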
/pyannote-audio/pyannote/audio/cli/train_config/scheduler/CyclicLR.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | _target_: pyannote.audio.cli.lr_schedulers.CyclicLR
3 | min_lr: 1e-8
4 | max_lr: 1e-3
5 | mode: triangular2
6 | patience: 50
7 |
--------------------------------------------------------------------------------
/pyannote-audio/MANIFEST.in:
--------------------------------------------------------------------------------
1 | recursive-include pyannote *.py
2 | recursive-include pyannote *.yaml
3 | recursive-include pyannote *.wav
4 | recursive-include pyannote *.rttm
5 | global-exclude *.pyc
6 | global-exclude __pycache__
7 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/cli/train_config/scheduler/CosineAnnealingWarmRestarts.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | _target_: pyannote.audio.cli.lr_schedulers.CosineAnnealingWarmRestarts
3 | min_lr: 1e-8
4 | max_lr: 1e-3
5 | patience: 1
6 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/cli/train_config/scheduler/ReduceLROnPlateau.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | _target_: pyannote.audio.cli.lr_schedulers.ReduceLROnPlateau
3 | min_lr: 1e-8
4 | max_lr: 1e-3
5 | factor: 0.5
6 | patience: 50
7 |
--------------------------------------------------------------------------------
/pyannote-audio/environment.yaml:
--------------------------------------------------------------------------------
1 | name: pyannote-audio
2 | channels:
3 | - defaults
4 | - conda-forge
5 | dependencies:
6 | - python==3.8.5
7 | - libsndfile==1.0.28
8 | - pip>=20.2
9 | - pip:
10 | - -r requirements.txt
11 |
--------------------------------------------------------------------------------
/pyannote-audio/tests/test_import_lib.py:
--------------------------------------------------------------------------------
1 | from pyannote.audio.core.model import Model
2 |
3 |
4 | def test_import_lib():
5 | """This is a dummy test, just to check
6 | if the lib can be successfully imported.
7 | """
8 | assert Model is not None
9 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/cli/train_config/config.yaml:
--------------------------------------------------------------------------------
1 | protocol: ???
2 |
3 | defaults:
4 | - task: SpeakerDiarization
5 | - model: PyanNet
6 | - optimizer: Adam
7 | - scheduler: CosineAnnealingWarmRestarts
8 | - trainer: default
9 | - hydra: default
10 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/cli/train_config/task/SpeakerDiarization.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | _target_: pyannote.audio.tasks.SpeakerDiarization
3 | duration: 5.0
4 | max_speakers_per_chunk: 3
5 | max_speakers_per_frame: 2
6 | batch_size: 32
7 | num_workers: 10
8 | pin_memory: False
9 |
--------------------------------------------------------------------------------
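
Note: the parameters above map directly onto the `SpeakerDiarization` task constructor. A simplified sketch of the training flow this config feeds, assuming a `pyannote.database` protocol is already set up (the real entry point is the Hydra-driven `pyannote-audio-train` CLI):

```python
import pytorch_lightning as pl

from pyannote.audio.models.segmentation import PyanNet
from pyannote.audio.tasks import SpeakerDiarization
from pyannote.database import FileFinder, get_protocol

# Placeholder protocol name: replace with your own pyannote.database setup.
protocol = get_protocol(
    "MyDatabase.SpeakerDiarization.MyProtocol",
    preprocessors={"audio": FileFinder()},
)

# Same hyperparameters as in SpeakerDiarization.yaml above.
task = SpeakerDiarization(
    protocol,
    duration=5.0,
    max_speakers_per_chunk=3,
    max_speakers_per_frame=2,
    batch_size=32,
)
model = PyanNet(task=task)

trainer = pl.Trainer(max_epochs=1)
trainer.fit(model)
```
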
/pyannote-audio/questions/from_memory.question.md:
--------------------------------------------------------------------------------
1 | ---
2 | title: "Can I apply pretrained pipelines on audio already loaded in memory?"
3 | alt_titles:
4 | - "Can I apply models on an audio array?"
5 | ---
6 |
7 | Yes: read [this tutorial](tutorials/applying_a_pipeline.ipynb) until the end.
8 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/cli/train_config/task/MultiLabelSegmentation.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | _target_: pyannote.audio.tasks.MultiLabelSegmentation
3 | duration: 3.0
4 | warm_up: 0.0
5 | balance: null
6 | weight: null
7 | batch_size: 32
8 | num_workers: null
9 | pin_memory: False
10 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/cli/train_config/task/VoiceActivityDetection.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | _target_: pyannote.audio.tasks.VoiceActivityDetection
3 | duration: 3.0
4 | warm_up: 0.0
5 | balance: null
6 | weight: null
7 | batch_size: 32
8 | num_workers: null
9 | pin_memory: False
10 |
--------------------------------------------------------------------------------
/diarizen/pipelines/utils.py:
--------------------------------------------------------------------------------
1 | # Licensed under the MIT license.
2 | # Copyright 2025 Brno University of Technology (author: Jiangyu Han, ihan@fit.vut.cz)
3 |
4 | def scp2path(scp_file):
5 |     """Return the list of audio paths from a Kaldi-style wav.scp file."""
6 |     with open(scp_file) as f:
7 |         return [line.strip().split()[1] for line in f]
8 |
--------------------------------------------------------------------------------
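
Note: a usage sketch for `scp2path`, assuming a Kaldi-style `wav.scp` where each line is `<recording-id> <path>` (see the `recipes/diar_ssl` data files later in this dump):

```python
from diarizen.pipelines.utils import scp2path

# Each wav.scp line is "<recording-id> <path>"; scp2path keeps only the paths.
paths = scp2path("recipes/diar_ssl/data/AMI_AliMeeting_AISHELL4/test/AMI/wav.scp")
print(paths[0])  # e.g. /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/IS1009a.wav
```
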
/pyannote-audio/faq.yml:
--------------------------------------------------------------------------------
1 | # FAQtory settings
2 |
3 | faq_url: "https://github.com/pyannote/pyannote-audio/blob/develop/FAQ.md" # Replace this with the URL to your FAQ.md!
4 |
5 | questions_path: "./questions" # Where questions should be stored
6 | output_path: "./FAQ.md" # Where FAQ.md should be generated
7 | templates_path: ".faq" # Path to templates
8 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/cli/train_config/model/PyanNet.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | _target_: pyannote.audio.models.segmentation.PyanNet
3 | sincnet:
4 | stride: 10
5 | lstm:
6 | hidden_size: 128
7 | num_layers: 2
8 | bidirectional: true
9 | monolithic: true
10 | dropout: 0.5
11 | linear:
12 | hidden_size: 128
13 | num_layers: 2
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/cli/train_config/task/SpeakerEmbedding.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | _target_: pyannote.audio.tasks.SupervisedRepresentationLearningWithArcFace
3 | min_duration: 2.0
4 | duration: 5.0
5 | num_classes_per_batch: 512
6 | num_chunks_per_class: 1
7 | margin: 2.0
8 | scale: 12.0
9 | num_workers: null
10 | pin_memory: False
11 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/utils/params.py:
--------------------------------------------------------------------------------
1 | # TODO - make it depth-recursive
2 | # TODO - switch to Omegaconf maybe?
3 |
4 | from typing import Optional
5 |
6 |
7 | def merge_dict(defaults: dict, custom: Optional[dict] = None):
8 | params = dict(defaults)
9 | if custom is not None:
10 | params.update(custom)
11 | return params
12 |
--------------------------------------------------------------------------------
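
Note: `merge_dict` is a shallow merge, which is what the "make it depth-recursive" TODO refers to: nested dicts from `custom` replace, rather than merge into, those from `defaults`. A behavior sketch:

```python
from pyannote.audio.utils.params import merge_dict

defaults = {"lstm": {"hidden_size": 128, "num_layers": 2}, "dropout": 0.5}
custom = {"lstm": {"num_layers": 4}}

merged = merge_dict(defaults, custom)
print(merged["lstm"])     # {'num_layers': 4} -- hidden_size is lost (shallow merge)
print(merged["dropout"])  # 0.5 -- untouched keys keep their defaults
```
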
/pyannote-audio/pyannote/audio/cli/train_config/model/SSeRiouSS.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | _target_: pyannote.audio.models.segmentation.SSeRiouSS
3 | wav2vec: WAVLM_BASE
4 | wav2vec_layer: -1
5 | lstm:
6 | hidden_size: 128
7 | num_layers: 4
8 | bidirectional: true
9 | monolithic: true
10 | dropout: 0.5
11 | linear:
12 | hidden_size: 128
13 | num_layers: 2
14 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/cli/train_config/task/OverlappedSpeechDetection.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | _target_: pyannote.audio.tasks.OverlappedSpeechDetection
3 | duration: 3.0
4 | warm_up: 0.0
5 | balance: null
6 | overlap:
7 | probability: 0.5
8 | snr_min: 0.0
9 | snr_max: 10.0
10 | weight: null
11 | batch_size: 32
12 | num_workers: null
13 | pin_memory: False
14 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | einops
2 | flit
3 | h5py
4 | joblib
5 | jupyterlab
6 | tensorboard
7 | librosa
8 | matplotlib
9 | numpy==1.26.4
10 | onnxruntime-gpu
11 | openpyxl # for saving results to excel
12 | pandas
13 | pesq
14 | pre-commit
15 | pystoi
16 | pyyaml
17 | scipy
18 | soundfile
19 | tabulate
20 | toml
21 | torchinfo
22 | tqdm
23 | accelerate==1.6.0
24 | thop
25 |
--------------------------------------------------------------------------------
/pyannote-audio/.faq/FAQ.md:
--------------------------------------------------------------------------------
1 |
2 | # Frequently Asked Questions
3 |
4 | {%- for question in questions %}
5 | - [{{ question.title }}](#{{ question.slug }})
6 | {%- endfor %}
7 |
8 |
9 | {%- for question in questions %}
10 |
11 |
12 | ## {{ question.title }}
13 |
14 | {{ question.body }}
15 |
16 | {%- endfor %}
17 |
18 |
19 |
20 | Generated by [FAQtory](https://github.com/willmcgugan/faqtory)
21 |
--------------------------------------------------------------------------------
/pyannote-audio/doc/source/index.rst:
--------------------------------------------------------------------------------
1 | ##############
2 | pyannote.audio
3 | ##############
4 |
5 | `pyannote.audio` is an open-source Python library that provides neural building blocks for speaker diarization.
6 |
7 | Installation
8 | ============
9 |
10 | ::
11 |
12 | $ conda create -n pyannote python=3.10
13 | $ conda activate pyannote
14 | $ pip install pyannote.audio
15 |
16 |
17 | API documentation
18 | =================
19 |
20 | .. toctree::
21 | :maxdepth: 2
22 |
--------------------------------------------------------------------------------
/pyannote-audio/questions/pyannote.question.md:
--------------------------------------------------------------------------------
1 | ---
2 | title: "How does one spell and pronounce pyannote.audio?"
3 | alt_titles:
4 | - "Why the name of the library?"
5 | - "Why the logo of the library?"
6 | ---
7 |
8 | 📝 Written in lower case: `pyannote.audio` (or `pyannote` if you are lazy). Not `PyAnnote` nor `PyAnnotate` (sic).
9 | 📢 Pronounced like the french verb `pianoter`. `pi` like in `pi`ano, not `py` like in `py`thon.
10 | 🎹 `pianoter` means to play the piano (hence the logo 🤯).
11 |
--------------------------------------------------------------------------------
/pyannote-audio/requirements.txt:
--------------------------------------------------------------------------------
1 | asteroid-filterbanks >=0.4
2 | einops >=0.6.0
3 | huggingface_hub >= 0.13.0
4 | lightning >= 2.0.1
5 | omegaconf >=2.1,<3.0
6 | pyannote.core >= 5.0.0
7 | pyannote.database >= 5.0.1
8 | pyannote.metrics >= 3.2
9 | pyannote.pipeline >= 3.0.1
10 | pytorch_metric_learning >= 2.1.0
11 | rich >= 12.0.0
12 | semver >= 3.0.0
13 | soundfile >= 0.12.1
14 | speechbrain >= 0.5.14
15 | tensorboardX >= 2.6
16 | torch >= 2.0.0
17 | torch_audiomentations >= 0.11.0
18 | torchaudio >= 2.1.1
19 | torchmetrics >= 0.11.0
20 |
--------------------------------------------------------------------------------
/pyannote-audio/tests/test_run_notebooks.py:
--------------------------------------------------------------------------------
1 | from glob import glob
2 |
3 | import papermill as pm
4 |
5 |
6 | def test_can_run_notebooks():
7 | # Search for all notebooks in directory
8 | notebooks = glob("**/notebook/**/*.ipynb")
9 | for nb in notebooks:
10 | try:
11 | pm.execute_notebook(
12 | nb, "/dev/null", progress_bar=False, kernel_name="python"
13 | )
14 | except Exception as e:
15 | # Which notebook caused the error
16 | raise Exception(nb, e)
17 |
--------------------------------------------------------------------------------
/recipes/diar_ssl/data/AMI_AliMeeting_AISHELL4/test/AMI/all.uem:
--------------------------------------------------------------------------------
1 | EN2002a 1 0.000 2142.709375
2 | EN2002b 1 0.000 1786.848000
3 | EN2002c 1 0.000 2972.256000
4 | EN2002d 1 0.000 2209.898688
5 | ES2004a 1 0.000 1049.354687
6 | ES2004b 1 0.000 2345.493375
7 | ES2004c 1 0.000 2334.368000
8 | ES2004d 1 0.000 2222.290687
9 | IS1009a 1 0.000 838.833313
10 | IS1009b 1 0.000 2052.333312
11 | IS1009c 1 0.000 1820.833312
12 | IS1009d 1 0.000 1944.500000
13 | TS3003a 1 0.000 1505.642625
14 | TS3003b 1 0.000 2210.304000
15 | TS3003c 1 0.000 2570.000000
16 | TS3003d 1 0.000 2618.200000
17 |
--------------------------------------------------------------------------------
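
Note: these UEM ("un-partitioned evaluation map") lines follow the `<file-id> <channel> <onset> <offset>` convention, with times in seconds, and delimit the scored region of each recording. A minimal, hypothetical parser (not part of either repo):

```python
def read_uem(uem_file):
    """Map each file-id to its list of (start, end) scoring regions, in seconds."""
    regions = {}
    with open(uem_file) as f:
        for line in f:
            file_id, _channel, start, end = line.split()
            regions.setdefault(file_id, []).append((float(start), float(end)))
    return regions

regions = read_uem("data/AMI_AliMeeting_AISHELL4/test/AMI/all.uem")
print(regions["EN2002a"])  # [(0.0, 2142.709375)]
```
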
/pyannote-audio/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: false
2 |
3 | contact_links:
4 |
5 | - name: Feature request
6 | url: https://github.com/pyannote/pyannote-audio/discussions
7 | about: Suggest an idea for this project.
8 |
9 | - name: Consulting
10 | url: https://herve.niderb.fr/consulting
11 | about: Using pyannote.audio in production? Make the most of it thanks to our consulting services.
12 |
13 | - name: Premium models
14 | url: https://forms.office.com/e/GdqwVgkZ5C
15 | about: We are considering selling premium models, extensions, or services around pyannote.audio.
16 |
--------------------------------------------------------------------------------
/pyannote-audio/tutorials/assets/sample.rttm:
--------------------------------------------------------------------------------
1 | SPEAKER sample 1 6.690 0.430 speaker90
2 | SPEAKER sample 1 7.550 0.800 speaker91
3 | SPEAKER sample 1 8.320 1.700 speaker90
4 | SPEAKER sample 1 9.920 1.110 speaker91
5 | SPEAKER sample 1 10.570 4.130 speaker90
6 | SPEAKER sample 1 14.490 3.430 speaker91
7 | SPEAKER sample 1 18.050 3.440 speaker90
8 | SPEAKER sample 1 18.150 0.440 speaker91
9 | SPEAKER sample 1 21.780 6.720 speaker91
10 | SPEAKER sample 1 27.850 2.150 speaker90
11 |
--------------------------------------------------------------------------------
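
Note: each line above is a trimmed RTTM record, `SPEAKER <file-id> <channel> <onset> <duration> <speaker>` (full NIST RTTM has ten columns, padding the unused ones with `<NA>`). A hypothetical reader for this trimmed layout:

```python
def read_rttm(rttm_file):
    """Parse trimmed RTTM lines into (file-id, onset, duration, speaker) tuples."""
    segments = []
    with open(rttm_file) as f:
        for line in f:
            _type, file_id, _channel, onset, duration, speaker = line.split()
            segments.append((file_id, float(onset), float(duration), speaker))
    return segments

for file_id, onset, duration, speaker in read_rttm("sample.rttm")[:2]:
    print(f"{speaker}: {onset:.2f}s .. {onset + duration:.2f}s")
```
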
/pyannote-audio/pyannote/audio/sample/sample.rttm:
--------------------------------------------------------------------------------
1 | SPEAKER sample 1 6.690 0.430 speaker90
2 | SPEAKER sample 1 7.550 0.800 speaker91
3 | SPEAKER sample 1 8.320 1.700 speaker90
4 | SPEAKER sample 1 9.920 1.110 speaker91
5 | SPEAKER sample 1 10.570 4.130 speaker90
6 | SPEAKER sample 1 14.490 3.430 speaker91
7 | SPEAKER sample 1 18.050 3.440 speaker90
8 | SPEAKER sample 1 18.150 0.440 speaker91
9 | SPEAKER sample 1 21.780 6.720 speaker91
10 | SPEAKER sample 1 27.850 2.150 speaker90
11 |
--------------------------------------------------------------------------------
/recipes/diar_ssl/data/AMI_AliMeeting_AISHELL4/test/AISHELL4/all.uem:
--------------------------------------------------------------------------------
1 | L_R003S01C02 1 0.000 2362.945
2 | L_R003S02C02 1 0.000 2246.708
3 | L_R003S03C02 1 0.000 2268.015
4 | L_R003S04C02 1 0.000 2324.984
5 | L_R004S01C01 1 0.000 2195.433
6 | L_R004S02C01 1 0.000 2226.67
7 | L_R004S03C01 1 0.000 2227.768
8 | L_R004S06C01 1 0.000 2210.35
9 | M_R003S01C01 1 0.000 2282.21
10 | M_R003S02C01 1 0.000 2321.259
11 | M_R003S04C01 1 0.000 2256.727
12 | M_R003S05C01 1 0.000 2341.215
13 | S_R003S01C01 1 0.000 2220.529
14 | S_R003S02C01 1 0.000 2361.298
15 | S_R003S03C01 1 0.000 2279.261
16 | S_R003S04C01 1 0.000 2393.927
17 | S_R004S01C01 1 0.000 2345.908
18 | S_R004S02C01 1 0.000 2319.371
19 | S_R004S03C01 1 0.000 2327.119
20 | S_R004S04C01 1 0.000 2299.396
21 |
--------------------------------------------------------------------------------
/pyannote-audio/.github/workflows/pypi.yml:
--------------------------------------------------------------------------------
1 | name: PyPI
2 |
3 | on:
4 | push:
5 | tags:
6 | - '*'
7 |
8 | jobs:
9 | deploy:
10 | runs-on: ubuntu-latest
11 | steps:
12 | - uses: actions/checkout@v1
13 | - name: Set up Python
14 | uses: actions/setup-python@v1
15 | with:
16 | python-version: '3.x'
17 | - name: Install dependencies
18 | run: |
19 | python -m pip install --upgrade pip
20 | pip install setuptools wheel twine
21 | - name: Build and publish
22 | env:
23 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
24 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
25 | run: |
26 | python setup.py sdist bdist_wheel
27 | twine upload dist/*
28 |
--------------------------------------------------------------------------------
/pyannote-audio/questions/streaming.question.md:
--------------------------------------------------------------------------------
1 | ---
2 | title: "Does pyannote support streaming speaker diarization?"
3 | alt_titles:
4 | - "Is it possible to do realtime speaker diarization?"
5 | - "Can it process online audio buffers?"
6 | ---
7 |
8 | **Short answer:** not out of the box, no.
9 |
10 | **Long answer:** [I](https://herve.niderb.fr) am looking for sponsors to add this feature. In the meantime, [`diart`](https://github.com/juanmc2005/StreamingSpeakerDiarization) is the closest you can get to a streaming `pyannote.audio`. You might also be interested in [this blog post](https://herve.niderb.fr/fastpages/2021/08/05/Streaming-voice-activity-detection-with-pyannote.html) about streaming voice activity detection based on `pyannote.audio`.
11 |
--------------------------------------------------------------------------------
/pyannote-audio/.github/stale.yml:
--------------------------------------------------------------------------------
1 | # Number of days of inactivity before an issue becomes stale
2 | daysUntilStale: 180
3 | # Number of days of inactivity before a stale issue is closed
4 | daysUntilClose: 30
5 | # Issues with these labels will never be considered stale
6 | exemptLabels:
7 | - pinned
8 | - security
9 | # Label to use when marking an issue as stale
10 | staleLabel: wontfix
11 | # Comment to post when marking an issue as stale. Set to `false` to disable
12 | markComment: >
13 | This issue has been automatically marked as stale because it has not had
14 | recent activity. It will be closed if no further activity occurs. Thank you
15 | for your contributions.
16 | # Comment to post when closing a stale issue. Set to `false` to disable
17 | closeComment: false
18 |
--------------------------------------------------------------------------------
/pyannote-audio/questions/bad_performance.question.md:
--------------------------------------------------------------------------------
1 | ---
2 | title: "How can I improve performance?"
3 | alt_titles:
4 | - "Pretrained pipelines do not produce good results on my data. What can I do?"
5 | - "It does not work! Help me!"
6 | ---
7 |
8 | **Long answer:**
9 |
10 | 1. Manually annotate dozens of conversations as precisely as possible.
11 | 2. Separate them into train (80%), development (10%) and test (10%) subsets.
12 | 3. Set up the data for use with [`pyannote.database`](https://github.com/pyannote/pyannote-database#speaker-diarization).
13 | 4. Follow [this recipe](https://github.com/pyannote/pyannote-audio/blob/develop/tutorials/adapting_pretrained_pipeline.ipynb).
14 | 5. Enjoy.
15 |
16 | **Also:** [I am available](https://herve.niderb.fr) for contracting to help you with that.
17 |
--------------------------------------------------------------------------------
/recipes/diar_ssl/data/AMI_AliMeeting_AISHELL4/test/AliMeeting/all.uem:
--------------------------------------------------------------------------------
1 | R8002_M8002_MS802 1 0.000 2066.52
2 | R8002_M8003_MS803 1 0.000 2042.126
3 | R8004_M8005_MS803 1 0.000 2093.014
4 | R8004_M8006_MS805 1 0.000 1965.215
5 | R8005_M8007_MS806 1 0.000 1869.669
6 | R8005_M8008_MS806 1 0.000 1960.822
7 | R8005_M8009_MS802 1 0.000 1776.99
8 | R8006_M8012_MS803 1 0.000 1832.183
9 | R8008_M8014_MS807 1 0.000 1937.092
10 | R8008_M8015_MS808 1 0.000 1896.185
11 | R8008_M8016_MS808 1 0.000 1838.98
12 | R8008_M8017_MS808 1 0.000 1820.052
13 | R8009_M8021_MS810 1 0.000 2008.399
14 | R8009_M8022_MS810 1 0.000 2018.56
15 | R8009_M8023_MS811 1 0.000 1962.73
16 | R8009_M8024_MS811 1 0.000 1854.474
17 | R8009_M8025_MS811 1 0.000 2065.057
18 | R8009_M8026_MS812 1 0.000 1983.113
19 | R8009_M8027_MS812 1 0.000 1881.248
20 | R8009_M8028_MS812 1 0.000 1923.043
21 |
--------------------------------------------------------------------------------
/pyannote-audio/tests/utils/preview.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from IPython.display import Audio
3 |
4 | from pyannote.audio.utils.preview import listen
5 | from pyannote.core import Segment
6 | from pyannote.database import FileFinder, get_protocol
7 |
8 |
9 | def test_file():
10 | protocol = get_protocol(
11 | "Debug.SpeakerDiarization.Debug", preprocessors={"audio": FileFinder()}
12 | )
13 | return next(protocol.train())
14 |
15 |
16 | def test_returns_audio_object():
17 | audio_file = test_file()
18 | ipython_audio = listen(audio_file)
19 | assert isinstance(ipython_audio, Audio)
20 |
21 |
22 | def test_can_crop():
23 | audio_file = test_file()
24 | listen(audio_file, Segment(0, 1))
25 |
26 |
27 | def test_fail_crop_too_large():
28 | with pytest.raises(ValueError):
29 | audio_file = test_file()
30 | duration = audio_file.duration
31 | listen(audio_file, Segment(0, duration * 2))
32 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/cli/config/hydra/default.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | run:
4 | dir: ${protocol}/${now:%Y-%m-%dT%H:%M:%S.%fZ}
5 |
6 | sweep:
7 | dir: ${protocol}/${now:%Y-%m-%dT%H:%M:%S.%fZ}
8 | subdir: ${hydra.job.num}
9 |
10 | output_subdir: ""
11 |
12 | help:
13 | app_name: pyannote-audio-train
14 |
15 | # Help header, customize to describe your app to your users
16 | header: == ${hydra.help.app_name} ==
17 |
18 | footer: |-
19 | Powered by Hydra (https://hydra.cc)
20 | Use --hydra-help to view Hydra specific help
21 |
22 | template: |-
23 | ${hydra.help.header}
24 |
25 | pyannote-audio-train protocol={protocol_name}
26 | task={task} task.param=...
27 | model={model} model.param=...
28 | optimizer={optimizer} optimizer.param=...
29 | scheduler={scheduler} scheduler.param=...
30 |
31 | ${hydra.help.footer}
32 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/cli/evaluate_config/hydra/default.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | run:
4 | dir: ${protocol}/${now:%Y-%m-%dT%H:%M:%S.%fZ}
5 |
6 | sweep:
7 | dir: ${protocol}/${now:%Y-%m-%dT%H:%M:%S.%fZ}
8 | subdir: ${hydra.job.num}
9 |
10 | output_subdir: ""
11 |
12 | help:
13 | app_name: pyannote-audio-eval
14 |
15 | # Help header, customize to describe your app to your users
16 | header: == ${hydra.help.app_name} ==
17 |
18 | footer: |-
19 | Powered by Hydra (https://hydra.cc)
20 | Use --hydra-help to view Hydra specific help
21 |
22 | template: |-
23 | ${hydra.help.header}
24 |
25 | pyannote-audio-eval registry={path_to_database.yml}
26 | protocol={protocol_name}
27 | subset={test | development | train}
28 | model={path_to_pretrained_model}
29 | warm_up={warm_up_duration_in_seconds}
30 |
31 | ${hydra.help.footer}
32 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/cli/train_config/hydra/default.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 |
3 | run:
4 | dir: ${protocol}/${now:%Y-%m-%dT%H:%M:%S.%fZ}
5 |
6 | sweep:
7 | dir: ${protocol}/${now:%Y-%m-%dT%H:%M:%S.%fZ}
8 | subdir: ${hydra.job.num}
9 |
10 | output_subdir: ""
11 |
12 | help:
13 | app_name: pyannote-audio-train
14 |
15 | # Help header, customize to describe your app to your users
16 | header: == ${hydra.help.app_name} ==
17 |
18 | footer: |-
19 | Powered by Hydra (https://hydra.cc)
20 | Use --hydra-help to view Hydra specific help
21 |
22 | template: |-
23 | ${hydra.help.header}
24 |
25 | pyannote-audio-train protocol={protocol_name}
26 | +task={task} task.param=...
27 | +model={model} model.param=...
28 | optimizer={optimizer} optimizer.param=...
29 | scheduler={scheduler} scheduler.param=...
30 |
31 | ${hydra.help.footer}
32 |
--------------------------------------------------------------------------------
/pyannote-audio/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: Tests
2 |
3 | on:
4 | push:
5 | branches: [develop]
6 | pull_request:
7 | branches: [develop]
8 |
9 | jobs:
10 | build:
11 | timeout-minutes: 20
12 | runs-on: ${{ matrix.os }}
13 | strategy:
14 | matrix:
15 | os: [ubuntu-latest]
16 | python-version: [3.8, 3.9, "3.10"]
17 | steps:
18 | - uses: actions/checkout@v2
19 | - name: Set up Python ${{ matrix.python-version }}
20 | uses: actions/setup-python@v2
21 | with:
22 | python-version: ${{ matrix.python-version }}
23 | - name: Install libsndfile
24 | if: matrix.os == 'ubuntu-latest'
25 | run: |
26 | sudo apt-get update
27 | sudo apt-get install libsndfile1
28 | - name: Install pyannote.audio
29 | run: |
30 | pip install -e .[dev,testing]
31 | - name: Test with pytest
32 | run: |
33 | pytest -k "not test_cli.py"
34 |
--------------------------------------------------------------------------------
/pyannote-audio/.github/workflows/test_cli.yml:
--------------------------------------------------------------------------------
1 | name: CLI tests
2 |
3 | on:
4 | push:
5 | branches: [develop]
6 | pull_request:
7 | branches: [develop]
8 |
9 | jobs:
10 | build:
11 | timeout-minutes: 20
12 | runs-on: ${{ matrix.os }}
13 | strategy:
14 | matrix:
15 | os: [ubuntu-latest]
16 | python-version: ["3.10"]
17 | steps:
18 | - uses: actions/checkout@v2
19 | - name: Set up Python ${{ matrix.python-version }}
20 | uses: actions/setup-python@v2
21 | with:
22 | python-version: ${{ matrix.python-version }}
23 | - name: Install libsndfile
24 | if: matrix.os == 'ubuntu-latest'
25 | run: |
26 | sudo apt-get update
27 | sudo apt-get install libsndfile1
28 | - name: Install pyannote.audio
29 | run: |
30 | pip install -e .[dev,testing,cli]
31 | - name: Test with pytest
32 | run: |
33 | pytest tests/test_cli.py
34 |
--------------------------------------------------------------------------------
/pyannote-audio/tests/tasks/test_specifications.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pyannote.database import FileFinder, get_protocol
3 |
4 | from pyannote.audio.core.model import Model
5 | from pyannote.audio.core.task import UnknownSpecificationsError
6 | from pyannote.audio.tasks import SpeakerDiarization
7 |
8 |
9 | @pytest.fixture()
10 | def protocol():
11 | return get_protocol(
12 | "Debug.SpeakerDiarization.Debug", preprocessors={"audio": FileFinder()}
13 | )
14 |
15 |
16 | def test_unknown_specifications_error_raised_on_non_setup_task(protocol):
17 | task = SpeakerDiarization(protocol=protocol)
18 | with pytest.raises(UnknownSpecificationsError):
19 | _ = task.specifications
20 |
21 |
22 | def test_unknown_specifications_error_raised_on_non_setup_model_task(protocol):
23 | task = SpeakerDiarization(protocol=protocol)
24 | model = Model.from_pretrained("pyannote/ci-segmentation")
25 | model.task = task
26 | with pytest.raises(UnknownSpecificationsError):
27 | _ = model.specifications
28 |
--------------------------------------------------------------------------------
/pyannote-audio/tests/test_clustering.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from pyannote.audio.pipelines.clustering import AgglomerativeClustering
4 |
5 |
6 | def test_agglomerative_clustering_num_cluster():
7 | """
8 | Make sure AgglomerativeClustering doesn't "over-merge" clusters when initial
9 | clustering already matches target num_clusters, cf
10 | https://github.com/pyannote/pyannote-audio/issues/1525
11 | """
12 |
13 | # 2 embeddings different enough
14 | embeddings = np.array([[1.0, 1.0, 1.0, 1.0], [1.0, 2.0, 1.0, 2.0]])
15 |
16 | # clustering with params that should yield 1 cluster per embedding
17 | clustering = AgglomerativeClustering().instantiate(
18 | {
19 | "method": "centroid",
20 | "min_cluster_size": 0,
21 | "threshold": 0.0,
22 | }
23 | )
24 |
25 | # request 2 clusters
26 | clusters = clustering.cluster(
27 | embeddings=embeddings, min_clusters=2, max_clusters=2, num_clusters=2
28 | )
29 | assert np.array_equal(clusters, np.array([0, 1]))
30 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/cli/train_config/trainer/default.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | _target_: pytorch_lightning.Trainer
3 | accelerator: auto
4 | accumulate_grad_batches: 1
5 | benchmark: null # TODO: automatically set to True when using fixed duration chunks
6 | deterministic: False
7 | check_val_every_n_epoch: 1
8 | devices: auto
9 | detect_anomaly: False
10 | enable_checkpointing: True
11 | enable_model_summary: True
12 | enable_progress_bar: True
13 | fast_dev_run: False
14 | gradient_clip_val: null
15 | gradient_clip_algorithm: norm
16 | limit_predict_batches: 1.0
17 | limit_test_batches: 1.0
18 | limit_train_batches: 1.0
19 | limit_val_batches: 1.0
20 | log_every_n_steps: 50
21 | max_epochs: 1000
22 | max_steps: -1
23 | max_time: null
24 | min_epochs: 1
25 | min_steps: null
26 | num_nodes: 1
27 | num_sanity_val_steps: 2
28 | overfit_batches: 0.0
29 | precision: 32
30 | profiler: null
31 | reload_dataloaders_every_n_epochs: 0
32 | use_distributed_sampler: True # TODO: check what this does exactly
33 | strategy: auto
34 | sync_batchnorm: False
35 | val_check_interval: 1.0
36 |
--------------------------------------------------------------------------------
/recipes/diar_ssl/data/AMI_AliMeeting_AISHELL4/test/AMI/wav.scp:
--------------------------------------------------------------------------------
1 | IS1009a /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/IS1009a.wav
2 | IS1009b /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/IS1009b.wav
3 | IS1009c /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/IS1009c.wav
4 | IS1009d /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/IS1009d.wav
5 | ES2004a /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/ES2004a.wav
6 | ES2004b /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/ES2004b.wav
7 | ES2004c /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/ES2004c.wav
8 | ES2004d /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/ES2004d.wav
9 | TS3003a /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/TS3003a.wav
10 | TS3003b /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/TS3003b.wav
11 | TS3003c /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/TS3003c.wav
12 | TS3003d /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/TS3003d.wav
13 | EN2002a /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/EN2002a.wav
14 | EN2002b /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/EN2002b.wav
15 | EN2002c /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/EN2002c.wav
16 | EN2002d /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/EN2002d.wav
17 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 BUT Speech@FIT
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/pyannote-audio/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 CNRS
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/pyannote-audio/.github/workflows/doc.yml:
--------------------------------------------------------------------------------
1 | name: Documentation
2 | on:
3 | push:
4 | branches:
5 | - master
6 |
7 | jobs:
8 | build-and-deploy:
9 | runs-on: ubuntu-latest
10 | strategy:
11 | max-parallel: 4
12 | matrix:
13 | python-version: ["3.9"]
14 |
15 | steps:
16 | - uses: actions/checkout@v1
17 | - name: Set up Python ${{ matrix.python-version }}
18 | uses: actions/setup-python@v1
19 | with:
20 | python-version: ${{ matrix.python-version }}
21 | - name: Install
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install .
25 | pip install -r doc/requirements.txt
26 | - name: Build documentation
27 | run: |
28 | make --directory=doc html
29 | touch ./doc/build/html/.nojekyll
30 | - name: Deploy
31 | env:
32 | ACTIONS_DEPLOY_KEY: ${{ secrets.ACTIONS_DEPLOY_KEY }}
33 | PUBLISH_BRANCH: gh-pages
34 | PUBLISH_DIR: ./doc/build/html
35 | SCRIPT_MODE: true
36 | run: |
37 | wget https://raw.githubusercontent.com/peaceiris/actions-gh-pages/v2/entrypoint.sh
38 | bash ./entrypoint.sh
39 |
--------------------------------------------------------------------------------
/pyannote-audio/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | exclude: '^docs/conf.py'
2 |
3 | repos:
4 | # # Clean Notebooks
5 | # - repo: https://github.com/kynan/nbstripout
6 | # rev: master
7 | # hooks:
8 | # - id: nbstripout
9 | # Format Code
10 | - repo: https://github.com/ambv/black
11 | rev: 22.3.0
12 | hooks:
13 | - id: black
14 |
15 | # Sort imports
16 | - repo: https://github.com/PyCQA/isort
17 | rev: 5.12.0
18 | hooks:
19 | - id: isort
20 | args: ["--profile", "black"]
21 |
22 | # Formatting, Whitespace, etc
23 | - repo: https://github.com/pre-commit/pre-commit-hooks
24 | rev: v2.2.3
25 | hooks:
26 | - id: trailing-whitespace
27 | - id: check-added-large-files
28 | args: ['--maxkb=1000']
29 | - id: check-ast
30 | - id: check-json
31 | - id: check-merge-conflict
32 | - id: check-xml
33 | - id: check-yaml
34 | - id: debug-statements
35 | - id: end-of-file-fixer
36 | - id: requirements-txt-fixer
37 | - id: mixed-line-ending
38 | args: ['--fix=no']
39 | - id: flake8
40 | args: ['--ignore=E203,E501,F811,E712,W503']
41 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/core/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2020 CNRS
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/models/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2020 CNRS
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2020 CNRS
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/tasks/embedding/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2020 CNRS
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/cli/train_config/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2020-2021 CNRS
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/tasks/segmentation/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2020 CNRS
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/cli/evaluate_config/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2020-2021 CNRS
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/torchmetrics/functional/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2022- CNRS
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/torchmetrics/functional/audio/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2022- CNRS
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
--------------------------------------------------------------------------------
/recipes/diar_ssl_pruning/convert_wavlm_from_hf.py:
--------------------------------------------------------------------------------
1 | # Licensed under the MIT license.
2 | # Copyright 2025 Brno University of Technology (author: Jiangyu Han, ihan@fit.vut.cz)
3 |
4 | import os
5 | import argparse
6 |
7 | from diarizen.models.pruning.utils import convert_wavlm
8 |
9 | def run(args):
10 | hf_dir = os.path.abspath(args.hf_dir)
11 | out_dir = os.path.abspath(args.out_dir)
12 |
13 | if not os.path.isdir(hf_dir):
14 | raise FileNotFoundError(f"HuggingFace directory does not exist: {hf_dir}")
15 |
16 | os.makedirs(out_dir, exist_ok=True)
17 | convert_wavlm(hf_dir, out_dir)
18 |
19 | if __name__ == "__main__":
20 | parser = argparse.ArgumentParser(
21 | description="Convert HuggingFace WavLM model to custom format "
22 | "(e.g. from /pre-trained/HF/wavlm-base-plus)"
23 | )
24 | parser.add_argument(
25 | "hf_dir", type=str,
26 | help="Path to the HuggingFace WavLM directory containing config.json and pytorch_model.bin."
27 | )
28 | parser.add_argument(
29 | "out_dir", type=str,
30 | help="Path to output directory where the converted model will be saved."
31 | )
32 |
33 | args = parser.parse_args()
34 | run(args)
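35 |
36 | # Example invocation (a sketch; both paths are illustrative placeholders):
37 | #   python convert_wavlm_from_hf.py /pre-trained/HF/wavlm-base-plus /pre-trained/converted/wavlm-base-plus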
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2020 CNRS
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | __import__("pkg_resources").declare_namespace(__name__)
24 |
--------------------------------------------------------------------------------
/recipes/diar_ssl/README.md:
--------------------------------------------------------------------------------
1 | # DiariZen EEND Module
2 | This directory contains scripts for training the DiariZen EEND module and for running global speaker diarization inference.
3 |
4 |
5 | ## Results (DER, collar=0s)
6 | | System | Features | AMI | AISHELL-4 | AliMeeting |
7 | |:------------|:----------------:|:------:|:------------:|:------------:|
8 | | [Pyannote v3.1](https://github.com/pyannote/pyannote-audio) | SincNet | 22.4 | 12.2 | 24.4 |
9 | | DiariZen | Fbank | 19.7 | 12.5 | 21.0 |
10 | | | WavLM-frozen | 17.0 | 11.7 | 19.9 |
11 | | | WavLM-updated | **15.4** | **11.7** | **17.6** |
12 |
13 |
14 | ## Citation
15 | If you find this work helpful, please consider citing:
16 | J. Han, F. Landini, J. Rohdin, A. Silnova, M. Diez, and L. Burget, [Leveraging Self-Supervised Learning for Speaker Diarization](https://arxiv.org/pdf/2409.09408), in Proc. ICASSP, 2025.
17 | ```
18 | @inproceedings{han2025leveraging,
19 | title={Leveraging self-supervised learning for speaker diarization},
20 | author={Han, Jiangyu and Landini, Federico and Rohdin, Johan and Silnova, Anna and Diez, Mireia and Burget, Luk{\'a}{\v{s}}},
21 | booktitle={Proc. ICASSP},
22 | year={2025}
23 | }
24 | ```
25 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/cli/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2020-2021 CNRS
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | from .pretrained import pretrained
24 |
25 | __all__ = [
26 | "pretrained",
27 | ]
28 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/augmentation/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2020 CNRS
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 |
24 | from .mix import MixSpeakerDiarization
25 |
26 | __all__ = ["MixSpeakerDiarization"]
27 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/models/segmentation/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2020- CNRS
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | from .PyanNet import PyanNet
24 | from .SSeRiouSS import SSeRiouSS
25 |
26 | __all__ = ["PyanNet", "SSeRiouSS"]
27 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/torchmetrics/classification/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2023- CNRS
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 |
24 | from .equal_error_rate import EqualErrorRate
25 |
26 | __all__ = [
27 | "EqualErrorRate",
28 | ]
29 |
--------------------------------------------------------------------------------
/pyannote-audio/questions/offline.question.md:
--------------------------------------------------------------------------------
1 | ---
2 | title: "Can I use gated models (and pipelines) offline?"
3 | alt_titles:
4 | - "Why does one need to authenticate to access the pretrained models?"
5 | - "Can I use pyannote.audio pretrained pipelines without the Hugginface token?"
6 | - "How can I solve the permission issue?"
7 | ---
8 |
9 | **Short answer**: yes, see [this tutorial](tutorials/applying_a_model.ipynb) for models and [that one](tutorials/applying_a_pipeline.ipynb) for pipelines.
10 |
11 | **Long answer**: gating models and pipelines allows [me](https://herve.niderb.fr) to learn a bit more about the `pyannote.audio` user base, which eventually helps me write grant proposals to make `pyannote.audio` even better. So, please fill in the gating forms as precisely as possible.
12 |
13 | For instance, before gating `pyannote/speaker-diarization`, I had no idea that so many people were relying on it in production. Hint: sponsors are more than welcome! Maintaining open-source libraries is time-consuming.
14 |
15 | That being said, this whole authentication process does not prevent you from using official `pyannote.audio` models offline (i.e. without going through the authentication process in every `docker run ...` or whatever you are using in production): see [this tutorial](tutorials/applying_a_model.ipynb) for models and [that one](tutorials/applying_a_pipeline.ipynb) for pipelines.
16 |
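17 | For example, once a pipeline (and its underlying models) has been downloaded to a local directory, loading it from that local copy should work without any token. A minimal sketch (the local path is a placeholder):
18 |
19 | ```python
20 | from pyannote.audio import Pipeline
21 |
22 | # point to a local copy of the pipeline instead of huggingface.co
23 | pipeline = Pipeline.from_pretrained("/path/to/local/config.yaml")
24 | ```
25 |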
--------------------------------------------------------------------------------
/pyannote-audio/tests/test_sample.py:
--------------------------------------------------------------------------------
1 | # The MIT License (MIT)
2 | #
3 | # Copyright (c) 2024- CNRS
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in
13 | # all copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 |
24 | def test_sample():
25 | from pyannote.audio.sample import SAMPLE_FILE
26 |
27 | assert "annotation" in SAMPLE_FILE
28 | assert "annotated" in SAMPLE_FILE
29 |
--------------------------------------------------------------------------------
/pyannote-audio/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | .env/
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *,cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 |
56 | # Sphinx documentation
57 | docs/_build/
58 |
59 | # PyBuilder
60 | target/
61 |
62 | # IPython Notebook
63 | .ipynb_checkpoints
64 |
65 | notebooks
66 |
67 | experiments
68 | *~
69 |
70 | *.npy
71 | *.pt
72 | *events.out.tfevents*
73 | *.csv
74 |
75 | # PyCharm
76 | .idea/
77 |
78 | gh-pages
79 | gh-pages.pub
80 |
81 | *.zip
82 | .mypy_cache/
83 | .vscode/
84 |
85 | **/lightning_logs/**
86 |
87 | # Version Output
88 | pyannote/audio/version.py
89 |
90 | # vim
91 | .vim
92 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/cli/pretrained.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2020-2021 CNRS
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 |
24 | from typing import Text
25 | from pyannote.audio import Model
26 |
27 |
28 | def pretrained(checkpoint: Text):
29 |     # map_location keeps all tensors on CPU, so checkpoints saved on GPU
30 |     # can be loaded on CPU-only machines
31 |     return Model.from_pretrained(checkpoint, map_location=lambda storage, loc: storage)
32 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/models/embedding/wespeaker/LICENSE.WeSpeaker:
--------------------------------------------------------------------------------
1 | Copyright (c) 2021 Shuai Wang (wsstriving@gmail.com)
2 | 2022 Zhengyang Chen (chenzhengyang117@gmail.com)
3 | 2023 Bing Han (hanbing97@sjtu.edu.cn)
4 |
5 | Licensed under the Apache License, Version 2.0 (the "License");
6 | you may not use this file except in compliance with the License.
7 | You may obtain a copy of the License at
8 |
9 | http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | Unless required by applicable law or agreed to in writing, software
12 | distributed under the License is distributed on an "AS IS" BASIS,
13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | See the License for the specific language governing permissions and
15 | limitations under the License.
16 |
17 | File `resnet.py` has been borrowed from WeSpeaker, which is available under the Apache License, Version 2.0.
18 |
19 | The original file is available at https://github.com/wenet-e2e/wespeaker/blob/c20d765295359e681321625fbefc1a02e8794163/wespeaker/models/resnet.py
20 |
21 | Neither Shuai Wang (@wsstriving on GitHub) nor I (Hervé Bredin, @hbredin on GitHub) am a lawyer, but we both agreed that putting this license file in this directory is enough to comply with the license. See https://github.com/pyannote/pyannote-audio/issues/1537#issuecomment-1808029836. If you know more about this potential MIT/Apache 2.0 compatibility issue, please let us know.
22 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/cli/lr_schedulers/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 CNRS
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 |
24 | from .CosineAnnealingWarmRestarts import CosineAnnealingWarmRestarts
25 | from .CyclicLR import CyclicLR
26 | from .ReduceLROnPlateau import ReduceLROnPlateau
27 |
28 | __all__ = ["ReduceLROnPlateau", "CyclicLR", "CosineAnnealingWarmRestarts"]
29 |
--------------------------------------------------------------------------------
/pyannote-audio/tests/conftest.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2020- CNRS
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 |
24 | def pytest_sessionstart(session):
25 | """
26 | Called after the Session object has been created and
27 | before performing collection and entering the run test loop.
28 | """
29 |
30 | from pyannote.database import registry
31 |
32 | registry.load_database("tests/data/database.yml")
33 |
--------------------------------------------------------------------------------
/pyannote-audio/doc/gen_docs.py:
--------------------------------------------------------------------------------
1 | """
2 | This script generates the rst docs for the API
3 | """
4 |
5 | import os
6 | from os import path
7 |
8 | bp = breakpoint
9 |
10 |
11 | def capitalise(s):
12 | news = ""
13 | for word in s.split("_"):
14 | news += word.capitalize()
15 | return news
16 |
17 |
18 | def process_dir(level, p):
19 | md = ""
20 | basename = path.basename(p)
21 |
22 | title = capitalise(basename)
23 | md += f"{'#'*level} {title}\n\n"
24 | subdirs = os.listdir(p)
25 |
26 | for f in subdirs:
27 |         m = path.join(p, f)
28 | if path.isdir(m):
29 | md += process_dir(level + 1, path.join(p, f))
30 | else:
31 | if "__" in f:
32 | continue
33 | module = m[3:].replace("/", ".")[:-3]
34 | md += f"""
35 | ```eval_rst
36 | .. automodule:: {module}
37 | :members:
38 |
39 | ```
40 |
41 | """
42 | return md
43 |
44 |
45 | DIR = "../pyannote/audio"
46 |
47 | for module in os.listdir(DIR):
48 |     # Each folder will become an rst file
49 | # Each file/folder will have a # prepended to it
50 | # Recursively we will add another # each level
51 |
52 | # Initialise Markdown
53 | md = ""
54 |
55 | subdir = path.join(DIR, module)
56 |
57 | # Skip if not directory
58 | if not path.isdir(subdir) or "__" in module:
59 | continue
60 |
61 | md += process_dir(1, subdir)
62 | with open(f"./source/api/{module}.md", "w") as f:
63 | f.write(md)
64 |
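65 | # Note: DIR and the "./source/api/" output path are relative, so this script is
66 | # meant to be run from inside the doc/ directory (e.g. `python gen_docs.py`).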
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2020-2021 CNRS
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | try:
24 | from .version import __version__, git_version # noqa: F401
25 | except ImportError:
26 | pass
27 |
28 |
29 | from .core.inference import Inference
30 | from .core.io import Audio
31 | from .core.model import Model
32 | from .core.pipeline import Pipeline
33 |
34 | __all__ = ["Audio", "Model", "Inference", "Pipeline"]
35 |
--------------------------------------------------------------------------------
/recipes/diar_ssl/data/AMI_AliMeeting_AISHELL4/test/AISHELL4/wav.scp:
--------------------------------------------------------------------------------
1 | S_R003S02C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/S_R003S02C01.wav
2 | L_R004S02C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/L_R004S02C01.wav
3 | S_R003S04C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/S_R003S04C01.wav
4 | L_R003S04C02 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/L_R003S04C02.wav
5 | S_R004S01C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/S_R004S01C01.wav
6 | L_R003S02C02 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/L_R003S02C02.wav
7 | S_R003S03C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/S_R003S03C01.wav
8 | S_R004S02C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/S_R004S02C01.wav
9 | L_R004S03C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/L_R004S03C01.wav
10 | L_R004S06C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/L_R004S06C01.wav
11 | L_R004S01C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/L_R004S01C01.wav
12 | M_R003S02C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/M_R003S02C01.wav
13 | L_R003S03C02 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/L_R003S03C02.wav
14 | S_R004S03C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/S_R004S03C01.wav
15 | L_R003S01C02 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/L_R003S01C02.wav
16 | S_R004S04C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/S_R004S04C01.wav
17 | M_R003S04C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/M_R003S04C01.wav
18 | M_R003S05C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/M_R003S05C01.wav
19 | M_R003S01C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/M_R003S01C01.wav
20 | S_R003S01C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/S_R003S01C01.wav
21 |
--------------------------------------------------------------------------------
/MODEL_LICENSE:
--------------------------------------------------------------------------------
1 | MODEL LICENSE
2 | =============
3 |
4 | The pre-trained model weights released in this repository ("the Models")
5 | are licensed under the Creative Commons Attribution-NonCommercial 4.0
6 | International License (CC BY-NC 4.0).
7 |
8 | Full legal text of the license is available at:
9 | https://creativecommons.org/licenses/by-nc/4.0/legalcode
10 |
11 | Summary of terms:
12 | - Attribution (BY): You must give appropriate credit, provide a link to the license,
13 | and indicate if changes were made.
14 | - NonCommercial (NC): You may not use the Models for commercial purposes.
15 |
16 | You are free to:
17 | - Use the Models for academic research and other non-commercial purposes.
18 | - Share and redistribute the Models, provided that you include proper attribution
19 | to the authors of DiariZen and retain this license notice.
20 | - Modify or adapt the Models for non-commercial use.
21 |
22 | You must not:
23 | - Use the Models or any derivatives thereof for commercial purposes,
24 | including but not limited to products, services, or paid offerings.
25 | - Remove or alter this license notice in any copies of the Models.
26 |
27 | Reason for this restriction:
28 | The training data used for the Models includes datasets licensed under terms
29 | that forbid commercial usage and/or derivative works
30 | (e.g., CC BY-NC, CC BY-NC-ND, or Research Only licenses).
31 | Therefore, the Models cannot be used commercially.
32 |
33 | Note:
34 | The source code in this repository is licensed separately under the MIT License
35 | (see LICENSE file). This MODEL_LICENSE applies only to the released pre-trained
36 | model weights.
37 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/models/embedding/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2020-2021 CNRS
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 |
24 | from .wespeaker import (
25 | WeSpeakerResNet34,
26 | WeSpeakerResNet152,
27 | WeSpeakerResNet221,
28 | WeSpeakerResNet293,
29 | )
30 | from .xvector import XVectorMFCC, XVectorSincNet
31 |
32 | __all__ = [
33 | "XVectorSincNet",
34 | "XVectorMFCC",
35 | "WeSpeakerResNet34",
36 | "WeSpeakerResNet152",
37 | "WeSpeakerResNet221",
38 | "WeSpeakerResNet293",
39 | ]
40 |
--------------------------------------------------------------------------------
/diarizen/optimization.py:
--------------------------------------------------------------------------------
1 | # Licensed under the MIT license.
2 | # Copied from https://github.com/haoxiangsnr/spiking-fullsubnet/blob/main/audiozen/optimization.py
3 | # Copyright 2024 Hong Kong Polytechnic University (author: Xiang Hao, haoxiangsnr@gmail.com)
4 |
5 | from functools import partial
6 |
7 | from torch.optim import Optimizer
8 | from torch.optim.lr_scheduler import LambdaLR
9 |
10 |
11 | def _get_constant_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int):
12 | if current_step < num_warmup_steps:
13 | return float(current_step) / float(max(1.0, num_warmup_steps))
14 | return 1.0
15 |
16 |
17 | def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
18 | lr_lambda = partial(_get_constant_schedule_with_warmup_lr_lambda, num_warmup_steps=num_warmup_steps)
19 | return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
20 |
21 |
22 | def _get_linear_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int):
23 | if current_step < num_warmup_steps:
24 | return float(current_step) / float(max(1, num_warmup_steps))
25 | return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))
26 |
27 |
28 | def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
29 | lr_lambda = partial(
30 | _get_linear_schedule_with_warmup_lr_lambda,
31 | num_warmup_steps=num_warmup_steps,
32 | num_training_steps=num_training_steps,
33 | )
34 | return LambdaLR(optimizer, lr_lambda, last_epoch)
35 |
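36 | # Usage sketch (illustrative values): warm up linearly for 1k steps, then decay to 0.
37 | #   optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
38 | #   scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=10000)
39 | #   ...then call scheduler.step() once after each optimizer.step()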
--------------------------------------------------------------------------------
/pyannote-audio/.faq/suggest.md:
--------------------------------------------------------------------------------
1 | Thank you for your issue.
2 |
3 | {%- if questions -%}
4 | {% if questions|length == 1 %}
5 | We found the following entry in the [FAQ]({{ faq_url }}) which you may find helpful:
6 | {%- else %}
7 | We found the following entries in the [FAQ]({{ faq_url }}) which you may find helpful:
8 | {%- endif %}
9 |
10 | {% for question in questions %}
11 | - [{{ question.title }}]({{ faq_url }}#{{ question.slug }})
12 | {%- endfor %}
13 |
14 | {%- else -%}
15 | You might want to check the [FAQ]({{ faq_url }}) if you haven't done so already.
16 | {%- endif %}
17 |
18 | Feel free to close this issue if you found an answer in the FAQ.
19 |
20 | If your issue is a feature request, please read [this](https://xyproblem.info/) first and update your request accordingly, if needed.
21 |
22 | If your issue is a bug report, please provide a [minimal reproducible example](https://stackoverflow.com/help/minimal-reproducible-example) as a link to a self-contained [Google Colab](https://colab.research.google.com/) notebook containing everything needed to reproduce the bug:
23 | - installation
24 | - data preparation
25 | - model download
26 | - etc.
27 |
28 | Providing an MRE will increase your chance of getting an answer from the community (either maintainers or other power users).
29 |
30 | Companies relying on `pyannote.audio` in production may contact [me](https://herve.niderb.fr) via email regarding:
31 | * paid scientific consulting around speaker diarization and speech processing in general;
32 | * custom models and tailored features (via the local tech transfer office).
33 |
34 | > This is an automated reply, generated by [FAQtory](https://github.com/willmcgugan/faqtory)
35 |
--------------------------------------------------------------------------------
/recipes/diar_ssl/data/AMI_AliMeeting_AISHELL4/dev/all.uem:
--------------------------------------------------------------------------------
1 | ES2011a 1 0.000 1113.845375
2 | ES2011b 1 0.000 1581.269375
3 | ES2011c 1 0.000 1616.064000
4 | ES2011d 1 0.000 1982.325375
5 | IB4001 1 0.000 1780.650688
6 | IB4002 1 0.000 1882.368000
7 | IB4003 1 0.000 2023.253375
8 | IB4004 1 0.000 2392.832000
9 | IB4010 1 0.000 2960.554688
10 | IB4011 1 0.000 2416.981375
11 | IS1008a 1 0.000 943.833313
12 | IS1008b 1 0.000 1768.500000
13 | IS1008c 1 0.000 1546.333312
14 | IS1008d 1 0.000 1480.833312
15 | TS3004a 1 0.000 1345.322625
16 | TS3004b 1 0.000 2246.058625
17 | TS3004c 1 0.000 2970.000000
18 | TS3004d 1 0.000 2750.800000
19 | R8001_M8004_MS801 1 0.000 1573.85
20 | R8003_M8001_MS801 1 0.000 2068
21 | R8007_M8010_MS803 1 0.000 1856.322
22 | R8007_M8011_MS806 1 0.000 1861.546
23 | R8008_M8013_MS807 1 0.000 2239.47
24 | R8009_M8018_MS809 1 0.000 1654.657
25 | R8009_M8019_MS810 1 0.000 1973.926
26 | R8009_M8020_MS810 1 0.000 1910.443
27 | 20200706_L_R001S01C01 1 0.000 2011.088
28 | 20200708_L_R002S07C01 1 0.000 1939.912
29 | 20200709_L_R002S08C01 1 0.000 2040.388
30 | 20200616_M_R001S01C01 1 0.000 1868.833
31 | 20200715_M_R002S06C01 1 0.000 1915.477
32 | 20200620_M_R002S08C01 1 0.000 2258.16
33 | 20200620_M_R002S07C01 1 0.000 2196.459
34 | 20200710_M_R002S03C01 1 0.000 1900.796
35 | 20200710_M_R002S07C01 1 0.000 1794.738
36 | 20200712_M_R002S05C01 1 0.000 1875.699
37 | 20200704_M_R002S06C01 1 0.000 2154.912
38 | 20200622_M_R002S02C01 1 0.000 2236.539
39 | 20200715_M_R002S03C01 1 0.000 1929.546
40 | 20200623_S_R001S06C01 1 0.000 2087.362
41 | 20200805_S_R001S01C01 1 0.000 2195.362
42 | 20200701_S_R001S03C01 1 0.000 2273.842
43 | 20200805_S_R001S03C01 1 0.000 2224.031
44 | 20200702_S_R001S02C01 1 0.000 2209.856
45 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/pipelines/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2020-2022 CNRS
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | from .multilabel import MultiLabelSegmentation
24 | from .overlapped_speech_detection import OverlappedSpeechDetection
25 | from .resegmentation import Resegmentation
26 | from .speaker_diarization import SpeakerDiarization
27 | from .voice_activity_detection import VoiceActivityDetection
28 |
29 | __all__ = [
30 | "VoiceActivityDetection",
31 | "OverlappedSpeechDetection",
32 | "SpeakerDiarization",
33 | "Resegmentation",
34 | "MultiLabelSegmentation",
35 | ]
36 |
--------------------------------------------------------------------------------
/recipes/diar_ssl/conf/pyannote_baseline.toml:
--------------------------------------------------------------------------------
1 | [meta]
2 | save_dir = "exp"
3 | seed = 3407
4 |
5 | [finetune]
6 | finetune = false
7 |
8 | [trainer]
9 | path = "trainer_single_opt.Trainer"
10 | [trainer.args]
11 | max_epochs = 100
12 | gradient_percentile = 90
13 | gradient_history_size = 1000
14 | save_max_score = false
15 | save_ckpt_interval = 1
16 | max_patience = 10
17 | max_num_checkpoints = 100
18 | gradient_accumulation_steps = 1
19 | validation_interval = 1
20 | freeze_wavlm = false
21 | lr_decay = false
22 | use_one_cycle_lr = false
23 |
24 | [optimizer]
25 | path = "torch.optim.AdamW"
26 | [optimizer.args]
27 | lr = 1e-3
28 |
29 | [model]
30 | path = "diarizen.models.eend.model_pyannote.Model"
31 | [model.args]
32 | chunk_size = 8
33 | selected_channel = 0
34 | max_speakers_per_chunk = 4
35 |
36 | [train_dataset]
37 | path = "dataset.DiarizationDataset"
38 | [train_dataset.args]
39 | scp_file = "data/AMI_AliMeeting_AISHELL4/train/wav.scp"
40 | rttm_file = "data/AMI_AliMeeting_AISHELL4/train/rttm"
41 | uem_file = "data/AMI_AliMeeting_AISHELL4/train/all.uem"
42 | chunk_size = 8
43 | chunk_shift = 6
44 | sample_rate = 16000
45 |
46 | [train_dataset.dataloader]
47 | batch_size = 32
48 | num_workers = 1
49 | drop_last = true
50 | pin_memory = true
51 |
52 | [validate_dataset]
53 | path = "dataset.DiarizationDataset"
54 | [validate_dataset.args]
55 | scp_file = "data/AMI_AliMeeting_AISHELL4/dev/wav.scp"
56 | rttm_file = "data/AMI_AliMeeting_AISHELL4/dev/rttm"
57 | uem_file = "data/AMI_AliMeeting_AISHELL4/dev/all.uem"
58 | chunk_size = 8
59 | chunk_shift = 8
60 | sample_rate = 16000
61 |
62 | [validate_dataset.dataloader]
63 | batch_size = 16
64 | num_workers = 1
65 | drop_last = true
66 | pin_memory = true
67 |
68 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/pipelines/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2022- CNRS
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | from .diarization import SpeakerDiarizationMixin
24 | from .getter import (
25 | PipelineAugmentation,
26 | PipelineInference,
27 | PipelineModel,
28 | get_augmentation,
29 | get_devices,
30 | get_inference,
31 | get_model,
32 | )
33 | from .oracle import oracle_segmentation
34 |
35 | __all__ = [
36 | "SpeakerDiarizationMixin",
37 | "oracle_segmentation",
38 | "get_augmentation",
39 | "PipelineAugmentation",
40 | "get_devices",
41 | "get_inference",
42 | "PipelineInference",
43 | "get_model",
44 | "PipelineModel",
45 | ]
46 |
--------------------------------------------------------------------------------
/recipes/diar_ssl/data/AMI_AliMeeting_AISHELL4/test/AliMeeting/wav.scp:
--------------------------------------------------------------------------------
1 | R8002_M8002_MS802 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8002_M8002_MS802.wav
2 | R8002_M8003_MS803 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8002_M8003_MS803.wav
3 | R8004_M8005_MS803 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8004_M8005_MS803.wav
4 | R8004_M8006_MS805 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8004_M8006_MS805.wav
5 | R8005_M8007_MS806 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8005_M8007_MS806.wav
6 | R8005_M8008_MS806 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8005_M8008_MS806.wav
7 | R8005_M8009_MS802 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8005_M8009_MS802.wav
8 | R8006_M8012_MS803 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8006_M8012_MS803.wav
9 | R8008_M8014_MS807 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8008_M8014_MS807.wav
10 | R8008_M8015_MS808 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8008_M8015_MS808.wav
11 | R8008_M8016_MS808 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8008_M8016_MS808.wav
12 | R8008_M8017_MS808 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8008_M8017_MS808.wav
13 | R8009_M8021_MS810 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8009_M8021_MS810.wav
14 | R8009_M8022_MS810 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8009_M8022_MS810.wav
15 | R8009_M8023_MS811 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8009_M8023_MS811.wav
16 | R8009_M8024_MS811 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8009_M8024_MS811.wav
17 | R8009_M8025_MS811 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8009_M8025_MS811.wav
18 | R8009_M8026_MS812 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8009_M8026_MS812.wav
19 | R8009_M8027_MS812 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8009_M8027_MS812.wav
20 | R8009_M8028_MS812 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8009_M8028_MS812.wav
21 |
--------------------------------------------------------------------------------
/pyannote-audio/tests/utils/probe_util_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from pyannote.audio.utils.probe import probe
5 |
6 |
7 | class Trunk(nn.Module):
8 | def __init__(self):
9 | super().__init__()
10 | self.layer1 = nn.Linear(1, 2)
11 | self.layer2 = nn.Linear(2, 3)
12 | self.layer3 = nn.Linear(3, 4)
13 |
14 | def forward(self, x):
15 | return self.layer3(self.layer2(self.layer1(x)))
16 |
17 |
18 | def test_probe_dict():
19 | trunk = Trunk()
20 | probe(trunk, {"probe1": "layer1"})
21 | out = trunk(
22 | torch.ones(
23 | 1,
24 | )
25 | )
26 | assert isinstance(out, dict)
27 | assert len(out.keys()) == 1
28 | assert isinstance(out["probe1"], torch.Tensor)
29 |
30 |
31 | def test_probe_output():
32 | trunk = Trunk()
33 | probe(trunk, {"probe1": "layer3"})
34 | out = trunk(
35 | torch.ones(
36 | 1,
37 | )
38 | )
39 | out = out["probe1"]
40 | tout = trunk.layer3(
41 | trunk.layer2(
42 | trunk.layer1(
43 | torch.ones(
44 | 1,
45 | )
46 | )
47 | )
48 | )
49 | assert torch.equal(tout, out)
50 |
51 |
52 | def test_probe_revert():
53 | trunk = Trunk()
54 | revert = probe(trunk, {"probe1": "layer3"})
55 | out = trunk(
56 | torch.ones(
57 | 1,
58 | )
59 | )
60 | assert isinstance(out, dict)
61 | revert()
62 | out = trunk(
63 | torch.ones(
64 | 1,
65 | )
66 | )
67 | assert isinstance(out, torch.Tensor)
68 |
69 |
70 | def test_probe_array():
71 | trunk = Trunk()
72 | probe(trunk, ["layer3"])
73 | out = trunk(
74 | torch.ones(
75 | 1,
76 | )
77 | )
78 | assert isinstance(out, dict)
79 |
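80 | # Taken together, these tests exercise the probe contract: probe(module, spec)
81 | # patches `module` so its forward returns a dict of intermediate layer outputs
82 | # (spec is either a {name: layer} mapping or a list of layer names), and the
83 | # value returned by probe is a callable that reverts the patch.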
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/torchmetrics/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2022- CNRS
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 |
24 | from .audio.diarization_error_rate import (
25 | DiarizationErrorRate,
26 | FalseAlarmRate,
27 | MissedDetectionRate,
28 | OptimalDiarizationErrorRate,
29 | OptimalDiarizationErrorRateThreshold,
30 | OptimalFalseAlarmRate,
31 | OptimalMissedDetectionRate,
32 | OptimalSpeakerConfusionRate,
33 | SpeakerConfusionRate,
34 | )
35 |
36 | __all__ = [
37 | "DiarizationErrorRate",
38 | "FalseAlarmRate",
39 | "MissedDetectionRate",
40 | "SpeakerConfusionRate",
41 | "OptimalDiarizationErrorRate",
42 | "OptimalFalseAlarmRate",
43 | "OptimalMissedDetectionRate",
44 | "OptimalSpeakerConfusionRate",
45 | "OptimalDiarizationErrorRateThreshold",
46 | ]
47 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/torchmetrics/audio/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2022- CNRS
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 |
24 | from .diarization_error_rate import (
25 | DiarizationErrorRate,
26 | FalseAlarmRate,
27 | MissedDetectionRate,
28 | OptimalDiarizationErrorRate,
29 | OptimalDiarizationErrorRateThreshold,
30 | OptimalFalseAlarmRate,
31 | OptimalMissedDetectionRate,
32 | OptimalSpeakerConfusionRate,
33 | SpeakerConfusionRate,
34 | )
35 |
36 | __all__ = [
37 | "DiarizationErrorRate",
38 | "SpeakerConfusionRate",
39 | "MissedDetectionRate",
40 | "FalseAlarmRate",
41 | "OptimalDiarizationErrorRate",
42 | "OptimalSpeakerConfusionRate",
43 | "OptimalMissedDetectionRate",
44 | "OptimalFalseAlarmRate",
45 | "OptimalDiarizationErrorRateThreshold",
46 | ]
47 |
--------------------------------------------------------------------------------
/pyannote-audio/tutorials/speaker_verification.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "metadata": {},
7 | "source": [
8 | "#### Speaker verification\n",
9 | "\n",
10 | "```python\n",
11 | "import torch\n",
12 | "from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding\n",
13 | "model = PretrainedSpeakerEmbedding(\n",
14 | " \"speechbrain/spkrec-ecapa-voxceleb\",\n",
15 | " device=torch.device(\"cuda\"))\n",
16 | "\n",
17 | "from pyannote.audio import Audio\n",
18 | "from pyannote.core import Segment\n",
19 | "audio = Audio(sample_rate=16000, mono=\"downmix\")\n",
20 | "\n",
21 | "# extract embedding for a speaker speaking between t=3s and t=6s\n",
22 | "speaker1 = Segment(3., 6.)\n",
23 | "waveform1, sample_rate = audio.crop(\"audio.wav\", speaker1)\n",
24 | "embedding1 = model(waveform1[None])\n",
25 | "\n",
26 | "# extract embedding for a speaker speaking between t=7s and t=12s\n",
27 | "speaker2 = Segment(7., 12.)\n",
28 | "waveform2, sample_rate = audio.crop(\"audio.wav\", speaker2)\n",
29 | "embedding2 = model(waveform2[None])\n",
30 | "\n",
31 | "# compare embeddings using \"cosine\" distance\n",
32 | "from scipy.spatial.distance import cdist\n",
33 | "distance = cdist(embedding1, embedding2, metric=\"cosine\")\n",
34 | "```\n"
35 | ]
36 | },
37 | {
38 | "cell_type": "markdown",
39 | "metadata": {},
40 | "source": []
41 | }
42 | ],
43 | "metadata": {
44 | "interpreter": {
45 | "hash": "41379f2c2a4eb17f5ac9a1f5014f4b793a0ead0b6469d8877f81a91eb030f53e"
46 | },
47 | "kernelspec": {
48 | "display_name": "Python 3.8.2 64-bit ('pyannote': conda)",
49 | "language": "python",
50 | "name": "python3"
51 | },
52 | "language_info": {
53 | "name": "python",
54 | "version": "3.8.2"
55 | }
56 | },
57 | "nbformat": 4,
58 | "nbformat_minor": 2
59 | }
60 |
--------------------------------------------------------------------------------
/recipes/diar_ssl/conf/wavlm_frozen_conformer.toml:
--------------------------------------------------------------------------------
1 | [meta]
2 | save_dir = "exp"
3 | seed = 3407
4 |
5 | [finetune]
6 | finetune = false
7 |
8 | [trainer]
9 | path = "trainer_single_opt.Trainer"
10 | [trainer.args]
11 | max_epochs = 100
12 | gradient_percentile = 90
13 | gradient_history_size = 1000
14 | save_max_score = false
15 | save_ckpt_interval = 1
16 | max_patience = 10
17 | max_num_checkpoints = 100
18 | gradient_accumulation_steps = 1
19 | validation_interval = 1
20 | freeze_wavlm = true
21 | lr_decay = false
22 | use_one_cycle_lr = false
23 |
24 | [optimizer]
25 | path = "torch.optim.AdamW"
26 | [optimizer.args]
27 | lr = 1e-3
28 |
29 | [model]
30 | path = "diarizen.models.eend.model_wavlm_conformer.Model"
31 | [model.args]
32 | wavlm_src = "/YOUR_PATH/WavLM-Base+.pt"
33 | wavlm_layer_num = 13
34 | wavlm_feat_dim = 768
35 | attention_in = 256
36 | ffn_hidden = 1024
37 | num_head = 4
38 | num_layer = 4
39 | dropout = 0.1
40 | chunk_size = 8
41 | use_posi = false
42 | output_activate_function = false
43 | selected_channel = 0
44 | max_speakers_per_chunk = 4
45 |
46 | [train_dataset]
47 | path = "dataset.DiarizationDataset"
48 | [train_dataset.args]
49 | scp_file = "data/AMI_AliMeeting_AISHELL4/train/wav.scp"
50 | rttm_file = "data/AMI_AliMeeting_AISHELL4/train/rttm"
51 | uem_file = "data/AMI_AliMeeting_AISHELL4/train/all.uem"
52 | chunk_size = 8
53 | chunk_shift = 6
54 | sample_rate = 16000
55 |
56 | [train_dataset.dataloader]
57 | batch_size = 32
58 | num_workers = 1
59 | drop_last = true
60 | pin_memory = true
61 |
62 | [validate_dataset]
63 | path = "dataset.DiarizationDataset"
64 | [validate_dataset.args]
65 | scp_file = "data/AMI_AliMeeting_AISHELL4/dev/wav.scp"
66 | rttm_file = "data/AMI_AliMeeting_AISHELL4/dev/rttm"
67 | uem_file = "data/AMI_AliMeeting_AISHELL4/dev/all.uem"
68 | chunk_size = 8
69 | chunk_shift = 8
70 | sample_rate = 16000
71 |
72 | [validate_dataset.dataloader]
73 | batch_size = 8
74 | num_workers = 1
75 | drop_last = true
76 | pin_memory = true
77 |
78 |
--------------------------------------------------------------------------------
/recipes/diar_ssl/conf/fbank_conformer.toml:
--------------------------------------------------------------------------------
1 | [meta]
2 | save_dir = "exp"
3 | seed = 3407
4 |
5 | [finetune]
6 | finetune = false
7 |
8 | [trainer]
9 | path = "trainer_single_opt.Trainer"
10 | [trainer.args]
11 | max_epochs = 100
12 | gradient_percentile = 90
13 | gradient_history_size = 1000
14 | save_max_score = false
15 | save_ckpt_interval = 1
16 | max_patience = 10
17 | max_num_checkpoints = 100
18 | gradient_accumulation_steps = 1
19 | validation_interval = 1
20 | freeze_wavlm = false
21 | lr_decay = false
22 | use_one_cycle_lr = false
23 |
24 | [optimizer]
25 | path = "torch.optim.AdamW"
26 | [optimizer.args]
27 | lr = 1e-3
28 |
29 | [model]
30 | path = "diarizen.models.eend.model_fbank_conformer.Model"
31 | [model.args]
32 | n_fft = 400
33 | n_mels = 80
34 | win_length = 25 # ms
35 | hop_length = 10 # ms
36 | sample_rate = 16000
37 | attention_in = 256
38 | ffn_hidden = 1024
39 | num_head = 4
40 | num_layer = 4
41 | dropout = 0.1
42 | chunk_size = 8
43 | use_posi = false
44 | output_activate_function = false
45 | selected_channel = 0
46 | max_speakers_per_chunk = 4
47 |
48 | [train_dataset]
49 | path = "dataset.DiarizationDataset"
50 | [train_dataset.args]
51 | scp_file = "data/AMI_AliMeeting_AISHELL4/train/wav.scp"
52 | rttm_file = "data/AMI_AliMeeting_AISHELL4/train/rttm"
53 | uem_file = "data/AMI_AliMeeting_AISHELL4/train/all.uem"
54 | chunk_size = 8
55 | chunk_shift = 6
56 | sample_rate = 16000
57 |
58 | [train_dataset.dataloader]
59 | batch_size = 16
60 | num_workers = 1
61 | drop_last = true
62 | pin_memory = true
63 |
64 | [validate_dataset]
65 | path = "dataset.DiarizationDataset"
66 | [validate_dataset.args]
67 | scp_file = "data/AMI_AliMeeting_AISHELL4/dev/wav.scp"
68 | rttm_file = "data/AMI_AliMeeting_AISHELL4/dev/rttm"
69 | uem_file = "data/AMI_AliMeeting_AISHELL4/dev/all.uem"
70 | chunk_size = 8
71 | chunk_shift = 8
72 | sample_rate = 16000
73 |
74 | [validate_dataset.dataloader]
75 | batch_size = 16
76 | num_workers = 1
77 | drop_last = true
78 | pin_memory = true
79 |
80 |
--------------------------------------------------------------------------------
/pyannote-audio/tests/tasks/test_reproducibility.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from lightning.pytorch import seed_everything
3 | from pyannote.database import FileFinder, get_protocol
4 |
5 | from pyannote.audio.models.segmentation.debug import SimpleSegmentationModel
6 | from pyannote.audio.tasks import VoiceActivityDetection
7 |
8 |
9 | def setup_tasks(task):
10 | protocol = get_protocol(
11 | "Debug.SpeakerDiarization.Debug", preprocessors={"audio": FileFinder()}
12 | )
13 | vad = task(protocol, duration=0.2, batch_size=32, num_workers=4)
14 | return protocol, vad
15 |
16 |
17 | def create_dl(model, task):
18 | m = model(task=task)
19 | m.prepare_data()
20 | m.setup()
21 | return task.train_dataloader()
22 |
23 |
24 | def get_next5(dl):
25 | last5 = []
26 | it = iter(dl)
27 | for i in range(5):
28 | last5.append(next(it))
29 | return last5
30 |
31 |
32 | def test_seeding_ensures_data_loaders():
33 | "Setting a global seed for the dataloaders ensures that we get data back in the same order"
34 |
35 | seed_everything(1)
36 | protocol, vad = setup_tasks(VoiceActivityDetection)
37 | dl = create_dl(SimpleSegmentationModel, vad)
38 | last5a = get_next5(dl)
39 |
40 | seed_everything(1)
41 | protocol, vad = setup_tasks(VoiceActivityDetection)
42 | dl = create_dl(SimpleSegmentationModel, vad)
43 | last5b = get_next5(dl)
44 |
45 | for i in range(len(last5b)):
46 | assert torch.equal(last5a[i]["X"], last5b[i]["X"])
47 |
48 |
49 | def test_different_seeds():
50 | "Changing the global seed will change the order of the data that loads"
51 |
52 | protocol, vad = setup_tasks(VoiceActivityDetection)
53 | seed_everything(4)
54 | dl = create_dl(SimpleSegmentationModel, vad)
55 | last5a = get_next5(dl)
56 |
57 | protocol, vad = setup_tasks(VoiceActivityDetection)
58 | seed_everything(5)
59 | dl = create_dl(SimpleSegmentationModel, vad)
60 | last5b = get_next5(dl)
61 |
62 | for i in range(5):
63 | assert not torch.equal(last5a[i]["X"], last5b[i]["X"])
64 |
--------------------------------------------------------------------------------
/recipes/diar_ssl/conf/wavlm_updated_conformer.toml:
--------------------------------------------------------------------------------
1 | [meta]
2 | save_dir = "exp"
3 | seed = 3407
4 |
5 | [finetune]
6 | finetune = false
7 |
8 | [trainer]
9 | path = "trainer_dual_opt.Trainer"
10 | [trainer.args]
11 | max_epochs = 100
12 | gradient_percentile = 90
13 | gradient_history_size = 1000
14 | save_max_score = false
15 | save_ckpt_interval = 1
16 | max_patience = 10
17 | max_num_checkpoints = 100
18 | gradient_accumulation_steps = 1
19 | validation_interval = 1
20 | freeze_wavlm = false
21 | lr_decay = false
22 | use_one_cycle_lr = false
23 |
24 | [optimizer_small]
25 | path = "torch.optim.AdamW"
26 | [optimizer_small.args]
27 | lr = 2e-5
28 |
29 | [optimizer_big]
30 | path = "torch.optim.AdamW"
31 | [optimizer_big.args]
32 | lr = 1e-3
33 |
34 | [model]
35 | path = "diarizen.models.eend.model_wavlm_conformer.Model"
36 | [model.args]
37 | wavlm_src = "/YOUR_PATH/WavLM-Base+.pt"
38 | wavlm_layer_num = 13
39 | wavlm_feat_dim = 768
40 | attention_in = 256
41 | ffn_hidden = 1024
42 | num_head = 4
43 | num_layer = 4
44 | dropout = 0.1
45 | chunk_size = 8
46 | use_posi = false
47 | output_activate_function = false
48 | selected_channel = 0
49 | max_speakers_per_chunk = 4
50 |
51 | [train_dataset]
52 | path = "dataset.DiarizationDataset"
53 | [train_dataset.args]
54 | scp_file = "data/AMI_AliMeeting_AISHELL4/train/wav.scp"
55 | rttm_file = "data/AMI_AliMeeting_AISHELL4/train/rttm"
56 | uem_file = "data/AMI_AliMeeting_AISHELL4/train/all.uem"
57 | chunk_size = 8
58 | chunk_shift = 6
59 | sample_rate = 16000
60 |
61 | [train_dataset.dataloader]
62 | batch_size = 16
63 | num_workers = 1
64 | drop_last = true
65 | pin_memory = true
66 |
67 | [validate_dataset]
68 | path = "dataset.DiarizationDataset"
69 | [validate_dataset.args]
70 | scp_file = "data/AMI_AliMeeting_AISHELL4/dev/wav.scp"
71 | rttm_file = "data/AMI_AliMeeting_AISHELL4/dev/rttm"
72 | uem_file = "data/AMI_AliMeeting_AISHELL4/dev/all.uem"
73 | chunk_size = 8
74 | chunk_shift = 8
75 | sample_rate = 16000
76 |
77 | [validate_dataset.dataloader]
78 | batch_size = 8
79 | num_workers = 1
80 | drop_last = true
81 | pin_memory = true
82 |
83 |
--------------------------------------------------------------------------------
/diarizen/models/module/wav2vec2/pruning_utils.py:
--------------------------------------------------------------------------------
1 | """Utility functions for pruning."""
2 |
3 | from typing import Union
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 |
9 | def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: str):
10 | "Prune linear layer in place."
11 | # NOTE: weight: (out_features, in_features), bias: (out_features,)
12 | if dim == "input":
13 | dim = 1
14 | layer.in_features = len(index)
15 | elif dim == "output":
16 | dim = 0
17 | layer.out_features = len(index)
18 |     else:
19 |         raise ValueError(f'`dim` must be "input" or "output", got {dim!r}')
20 |
21 | layer.weight = nn.Parameter(layer.weight.index_select(dim, index).clone().detach())
22 | if layer.bias is not None and dim == 0:
23 | layer.bias = nn.Parameter(layer.bias.index_select(0, index).clone().detach())
24 |
25 |
26 | def prune_conv1d_layer(layer: nn.Conv1d, index: torch.LongTensor, dim: str):
27 | """Prune conv1d in place."""
28 | # NOTE: weight: (out_channels, in_channels, kernel_size), bias: (out_channels,)
29 | if dim == "input":
30 | dim = 1
31 | layer.in_channels = len(index)
32 | elif dim == "output":
33 | dim = 0
34 | layer.out_channels = len(index)
35 |     else:
36 |         raise ValueError(f'`dim` must be "input" or "output", got {dim!r}')
37 |
38 | layer.weight = nn.Parameter(layer.weight.index_select(dim, index).clone().detach())
39 | if layer.bias is not None and dim == 0:
40 | layer.bias = nn.Parameter(layer.bias.index_select(0, index).clone().detach())
41 |
42 |
43 | def prune_layer_norm(layernorm: Union[nn.LayerNorm, nn.GroupNorm], index: torch.LongTensor):
44 | """Prune layer norm or group norm in place."""
45 | layernorm.weight = nn.Parameter(layernorm.weight.index_select(0, index).clone().detach())
46 | layernorm.bias = nn.Parameter(layernorm.bias.index_select(0, index).clone().detach())
47 | if isinstance(layernorm, nn.LayerNorm):
48 | layernorm.normalized_shape = (len(index),)
49 | elif isinstance(layernorm, nn.GroupNorm):
50 | layernorm.num_groups = len(index)
51 | layernorm.num_channels = len(index)
52 |
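53 |
54 | # Minimal usage sketch (hypothetical): keep output units {0, 2} of a 4-unit
55 | # linear layer; the weight shrinks from (4, 3) to (2, 3) and `out_features`
56 | # is updated accordingly.
57 | if __name__ == "__main__":
58 |     layer = nn.Linear(3, 4)
59 |     keep = torch.tensor([0, 2], dtype=torch.long)
60 |     prune_linear_layer(layer, keep, dim="output")
61 |     assert layer.weight.shape == (2, 3) and layer.out_features == 2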
--------------------------------------------------------------------------------
/recipes/diar_ssl_pruning/conf/dual_opt_common.toml:
--------------------------------------------------------------------------------
1 | [meta]
2 | save_dir = "exp"
3 | seed = 3407
4 |
5 | [finetune]
6 | finetune = false
7 |
8 | [trainer]
9 | path = "trainer_dual_opt.Trainer"
10 | [trainer.args]
11 | debug = false
12 | max_epochs = 20
13 | gradient_percentile = 90
14 | gradient_history_size = 1000
15 | save_max_score = false
16 | save_ckpt_interval = 1
17 | max_patience = 5
18 | max_num_checkpoints = 100
19 | gradient_accumulation_steps = 1
20 | validation_interval = 1
21 | freeze_wavlm = false
22 | lr_decay = false
23 | use_one_cycle_lr = false
24 |
25 | [optimizer_small]
26 | path = "torch.optim.AdamW"
27 | [optimizer_small.args]
28 | lr = 2e-5
29 |
30 | [optimizer_big]
31 | path = "torch.optim.AdamW"
32 | [optimizer_big.args]
33 | lr = 1e-3
34 |
35 | [model]
36 | path = "diarizen.models.eend.model_wavlm_conformer.Model"
37 | [model.args]
38 | wavlm_src = "/YOUR_PATH/wavlm-base-plus-finetuned.bin"
39 | wavlm_layer_num = 13
40 | wavlm_feat_dim = 768
41 | attention_in = 256
42 | ffn_hidden = 1024
43 | num_head = 4
44 | num_layer = 4
45 | dropout = 0.1
46 | chunk_size = 8
47 | use_posi = false
48 | output_activate_function = false
49 | selected_channel = 0
50 | max_speakers_per_chunk = 4
51 |
52 | [train_dataset]
53 | path = "dataset.DiarizationDataset"
54 | [train_dataset.args]
55 | scp_file = "data/AMI_AliMeeting_AISHELL4/train/wav.scp"
56 | rttm_file = "data/AMI_AliMeeting_AISHELL4/train/rttm"
57 | uem_file = "data/AMI_AliMeeting_AISHELL4/train/all.uem"
58 | chunk_size = 8
59 | chunk_shift = 6
60 | sample_rate = 16000
61 |
62 | [train_dataset.dataloader]
63 | batch_size = 16
64 | num_workers = 1
65 | drop_last = true
66 | pin_memory = true
67 |
68 | [validate_dataset]
69 | path = "dataset.DiarizationDataset"
70 | [validate_dataset.args]
71 | scp_file = "data/AMI_AliMeeting_AISHELL4/train/wav.scp"
72 | rttm_file = "data/AMI_AliMeeting_AISHELL4/train/rttm"
73 | uem_file = "data/AMI_AliMeeting_AISHELL4/train/all.uem"
74 | chunk_size = 8
75 | chunk_shift = 8
76 | sample_rate = 16000
77 |
78 | [validate_dataset.dataloader]
79 | batch_size = 16
80 | num_workers = 1
81 | drop_last = true
82 | pin_memory = true
83 |
84 |
--------------------------------------------------------------------------------
/recipes/diar_ssl_pruning/conf/dual_opt_common_large.toml:
--------------------------------------------------------------------------------
1 | [meta]
2 | save_dir = "exp"
3 | seed = 3407
4 |
5 | [finetune]
6 | finetune = false
7 |
8 | [trainer]
9 | path = "trainer_dual_opt.Trainer"
10 | [trainer.args]
11 | debug = false
12 | max_epochs = 20
13 | gradient_percentile = 90
14 | gradient_history_size = 1000
15 | save_max_score = false
16 | save_ckpt_interval = 1
17 | max_patience = 5
18 | max_num_checkpoints = 100
19 | gradient_accumulation_steps = 1
20 | validation_interval = 1
21 | freeze_wavlm = false
22 | lr_decay = false
23 | use_one_cycle_lr = false
24 |
25 | [optimizer_small]
26 | path = "torch.optim.AdamW"
27 | [optimizer_small.args]
28 | lr = 2e-5
29 |
30 | [optimizer_big]
31 | path = "torch.optim.AdamW"
32 | [optimizer_big.args]
33 | lr = 1e-3
34 |
35 | [model]
36 | path = "diarizen.models.eend.model_wavlm_conformer.Model"
37 | [model.args]
38 | wavlm_src = "/YOUR_PATH/wavlm-large-finetuned.bin"
39 | wavlm_layer_num = 25
40 | wavlm_feat_dim = 1024
41 | attention_in = 256
42 | ffn_hidden = 1024
43 | num_head = 4
44 | num_layer = 4
45 | dropout = 0.1
46 | chunk_size = 8
47 | use_posi = false
48 | output_activate_function = false
49 | selected_channel = 0
50 | max_speakers_per_chunk = 4
51 |
52 | [train_dataset]
53 | path = "dataset.DiarizationDataset"
54 | [train_dataset.args]
55 | scp_file = "data/AMI_AliMeeting_AISHELL4/train/wav.scp"
56 | rttm_file = "data/AMI_AliMeeting_AISHELL4/train/rttm"
57 | uem_file = "data/AMI_AliMeeting_AISHELL4/train/all.uem"
58 | chunk_size = 8
59 | chunk_shift = 6
60 | sample_rate = 16000
61 |
62 | [train_dataset.dataloader]
63 | batch_size = 16
64 | num_workers = 1
65 | drop_last = true
66 | pin_memory = true
67 |
68 | [validate_dataset]
69 | path = "dataset.DiarizationDataset"
70 | [validate_dataset.args]
71 | scp_file = "data/AMI_AliMeeting_AISHELL4/dev/wav.scp"
72 | rttm_file = "data/AMI_AliMeeting_AISHELL4/dev/rttm"
73 | uem_file = "data/AMI_AliMeeting_AISHELL4/dev/all.uem"
74 | chunk_size = 8
75 | chunk_shift = 8
76 | sample_rate = 16000
77 |
78 | [validate_dataset.dataloader]
79 | batch_size = 16
80 | num_workers = 1
81 | drop_last = true
82 | pin_memory = true
83 |
84 |
--------------------------------------------------------------------------------
/diarizen/trainer_utils.py:
--------------------------------------------------------------------------------
1 | # Licensed under the MIT license.
2 | # Copy from https://github.com/haoxiangsnr/spiking-fullsubnet/blob/main/audiozen/trainer_utils.py
3 | # Copyright 2024 Hong Kong Polytechnic University (author: Xiang Hao, haoxiangsnr@gmail.com)
4 |
5 | import numpy as np
6 | import torch
7 | from accelerate.utils import set_seed
8 |
9 |
10 | def seed_worker(_):
11 | """Helper function to set worker seed during Dataloader initialization.
12 |
13 | In recent check-ins, we may have no longer needed this function because PyTorch has already set the worker seed
14 | for numpy and random. But there is no adverse effect to keeping this function, since the initial_seed is
15 | inner_seed + worker_ids.
16 | """
17 | worker_seed = torch.initial_seed() % 2**32
18 | set_seed(worker_seed)
19 |
20 |
21 | def has_length(dataset):
22 | """
23 |     Checks whether the dataset implements __len__() without raising an error.
24 | """
25 | try:
26 | return len(dataset) is not None
27 | except TypeError:
28 | # TypeError: len() of unsized object
29 | return False
30 |
31 |
32 | class TrainerState:
33 | def __init__(self, save_max_score) -> None:
34 | self.epochs_trained = 0
35 | self.steps_trained = 0
36 |
37 | self.patience = 0
38 |
39 | self.best_score = -np.inf if save_max_score else np.inf
40 | self.best_score_epoch = 0
41 |
42 | def load_state_dict(self, state_dict: dict) -> None:
43 | self.epochs_trained = state_dict["epochs_trained"]
44 | self.steps_trained = state_dict["steps_trained"]
45 |
46 | self.best_score = state_dict["best_score"]
47 | self.best_score_epoch = state_dict["best_score_epoch"]
48 |
49 | self.patience = state_dict["patience"]
50 |
51 | def state_dict(self) -> dict:
52 | return {
53 | "epochs_trained": self.epochs_trained,
54 | "steps_trained": self.steps_trained,
55 | "patience": self.patience,
56 | "best_score": self.best_score,
57 | "best_score_epoch": self.best_score_epoch,
58 | }
59 |
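60 |
61 | # Minimal usage sketch (hypothetical): TrainerState round-trips through
62 | # state_dict() / load_state_dict(), e.g. when resuming from a checkpoint.
63 | if __name__ == "__main__":
64 |     state = TrainerState(save_max_score=False)  # lower score is better, so best starts at +inf
65 |     state.epochs_trained = 3
66 |     resumed = TrainerState(save_max_score=False)
67 |     resumed.load_state_dict(state.state_dict())
68 |     assert resumed.epochs_trained == 3 and resumed.best_score == np.inf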
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2020-2021 CNRS
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | from .segmentation.multilabel import MultiLabelSegmentation # isort:skip
24 | from .segmentation.speaker_diarization import SpeakerDiarization # isort:skip
25 | from .segmentation.voice_activity_detection import VoiceActivityDetection # isort:skip
26 | from .segmentation.overlapped_speech_detection import ( # isort:skip
27 | OverlappedSpeechDetection,
28 | )
29 | from .embedding.arcface import SupervisedRepresentationLearningWithArcFace # isort:skip
30 |
31 | # Segmentation has been renamed to SpeakerDiarization but we keep Segmentation here for backward compatibility
32 | Segmentation = SpeakerDiarization
33 |
34 | # SpeakerEmbedding is more human-friendly
35 | SpeakerEmbedding = SupervisedRepresentationLearningWithArcFace
36 |
37 | __all__ = [
38 | "SpeakerDiarization",
39 | "VoiceActivityDetection",
40 | "OverlappedSpeechDetection",
41 | "MultiLabelSegmentation",
42 | "SpeakerEmbedding",
43 | "Segmentation",
44 | ]
45 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/sample/__init__.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2024- CNRS
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 |
24 | from pathlib import Path
25 |
26 | from pyannote.core import Annotation, Segment, Timeline
27 | from pyannote.database.util import load_rttm
28 |
29 | from pyannote.audio.core.io import Audio, AudioFile
30 |
31 |
32 | def _sample() -> AudioFile:
33 | sample_wav = Path(__file__).parent / "sample.wav"
34 | uri = "sample"
35 |
36 | audio = Audio()
37 | waveform, sample_rate = audio(sample_wav)
38 |
39 | sample_rttm = Path(__file__).parent / "sample.rttm"
40 |
41 | annotation: Annotation = load_rttm(sample_rttm)[uri]
42 | duration = audio.get_duration(sample_wav)
43 |
44 | annotated: Timeline = Timeline([Segment(0.0, duration)], uri=uri)
45 |
46 | return {
47 | "audio": sample_wav,
48 | "uri": "sample",
49 | "waveform": waveform,
50 | "sample_rate": sample_rate,
51 | "annotation": annotation,
52 | "annotated": annotated,
53 | }
54 |
55 |
56 | SAMPLE_FILE = _sample()
57 |
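58 |
59 | # Minimal usage sketch (hypothetical): SAMPLE_FILE behaves like any other
60 | # pyannote AudioFile mapping, e.g.
61 | #     from pyannote.audio.sample import SAMPLE_FILE
62 | #     SAMPLE_FILE["annotation"].labels()  # speaker labels of the sample file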
--------------------------------------------------------------------------------
/pyannote-audio/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from pathlib import Path
4 |
5 | from pkg_resources import VersionConflict, require
6 | from setuptools import find_packages, setup
7 |
8 | with open("README.md", mode="r", encoding="utf-8") as f:
9 | long_description = f.read()
10 |
11 | with open("requirements.txt", mode="r", encoding="utf-8") as f:
12 | requirements = f.read().splitlines()
13 |
14 | try:
15 | require("setuptools>=38.3")
16 | except VersionConflict:
17 | print("Error: version of setuptools is too old (<38.3)!")
18 | sys.exit(1)
19 |
20 |
21 | ROOT_DIR = Path(__file__).parent.resolve()
22 | # Creating the version file
23 |
24 | with open("version.txt", mode="r", encoding="utf-8") as f:
25 | version = f.read()
26 |
27 | version = version.strip()
28 | sha = "Unknown"
29 |
30 | if os.getenv("BUILD_VERSION"):
31 | version = os.getenv("BUILD_VERSION")
32 | elif sha != "Unknown":
33 | version += "+" + sha[:7]
34 | print("-- Building version " + version)
35 |
36 | version_path = ROOT_DIR / "pyannote" / "audio" / "version.py"
37 |
38 | with open(version_path, mode="w", encoding="utf-8") as f:
39 | f.write("__version__ = '{}'\n".format(version))
40 |
41 | if __name__ == "__main__":
42 | setup(
43 | name="pyannote.audio",
44 | namespace_packages=["pyannote"],
45 | version=version,
46 | packages=find_packages(),
47 | install_requires=requirements,
48 | description="Neural building blocks for speaker diarization",
49 | long_description=long_description,
50 | long_description_content_type="text/markdown",
51 | author="Hervé Bredin",
52 | author_email="herve.bredin@irit.fr",
53 | url="https://github.com/pyannote/pyannote-audio",
54 | classifiers=[
55 | "Development Status :: 4 - Beta",
56 | "Intended Audience :: Science/Research",
57 | "License :: OSI Approved :: MIT License",
58 | "Natural Language :: English",
59 | "Programming Language :: Python :: 3.8",
60 | "Programming Language :: Python :: 3.9",
61 | "Programming Language :: Python :: 3.10",
62 | "Topic :: Scientific/Engineering",
63 | ],
64 | )
65 |
--------------------------------------------------------------------------------
/diarizen/noam_updater.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Licensed under the MIT license.
3 | # Copyright 2022 Brno University of Technology (author: Federico Landini, landini@fit.vut.cz)
4 |
5 | from math import sqrt
6 |
7 | import torch.optim as optim
8 | from typing import Any, Dict, Optional, Union
9 |
10 | class NoamOpt:
11 | "Optim wrapper that implements rate."
12 | def __init__(self, model_size: int, warmup: int, optimizer: optim) -> None:
13 | self.optimizer = optimizer
14 | self._step = 0
15 | self.warmup = warmup
16 | self.model_size = model_size
17 | self._rate = 0
18 |
19 | def state_dict(self) -> Dict[str, Any]:
20 | """Returns the state of the warmup scheduler as a :class:`dict`.
21 | It contains an entry for every variable in self.__dict__ which
22 | is not the optimizer.
23 | """
24 | return {
25 | key: value
26 | for key, value in self.__dict__.items() if key != 'optimizer'}
27 |
28 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
29 | """Loads the warmup scheduler's state.
30 | Arguments:
31 | state_dict (dict): warmup scheduler state.
32 | Should be an object returned from a call to :meth:`state_dict`.
33 | """
34 | self.__dict__.update(state_dict)
35 |
36 | def step(self) -> None:
37 | "Update parameters and rate"
38 | self._step += 1
39 | rate = self.rate()
40 | for p in self.optimizer.param_groups:
41 | p['lr'] = rate
42 | self._rate = rate
43 | self.optimizer.step()
44 |
45 |     def rate(self, step: Optional[int] = None) -> float:
46 |         "Compute the Noam learning rate for the given step (defaults to the current step)."
47 | if step is None:
48 | step = self._step
49 | return (
50 | self.model_size ** (-0.5) *
51 | min(step ** (-0.5), step * self.warmup ** (-1.5)))
52 |
53 | def get_rate(self) -> float:
54 | return self._rate
55 |
56 | def zero_grad(self) -> None:
57 | self.optimizer.zero_grad()
58 |
59 |
60 | def get_rate(optimizer: Union[NoamOpt, optim.Optimizer]) -> float:
61 | if isinstance(optimizer, NoamOpt):
62 | return optimizer.get_rate()
63 | else:
64 | for param_group in optimizer.param_groups:
65 | return param_group['lr']
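66 |
67 |
68 | # Minimal usage sketch (hypothetical): wrap a torch optimizer; the rate warms
69 | # up for `warmup` steps, then decays as step**-0.5, both scaled by
70 | # model_size**-0.5.
71 | if __name__ == "__main__":
72 |     import torch
73 |     param = torch.nn.Parameter(torch.zeros(1))
74 |     noam = NoamOpt(model_size=256, warmup=4000, optimizer=optim.AdamW([param], lr=0.0))
75 |     noam.step()  # after step 1, lr = 256**-0.5 * min(1**-0.5, 1 * 4000**-1.5)
76 |     assert abs(get_rate(noam) - 256 ** (-0.5) * 4000 ** (-1.5)) < 1e-12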
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/utils/version.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2020- CNRS
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | from typing import Text
24 |
25 | from semver import VersionInfo
26 |
27 |
28 | def check_version(library: Text, theirs: Text, mine: Text, what: Text = "Pipeline"):
29 |
30 | theirs = ".".join(theirs.split(".")[:3])
31 | mine = ".".join(mine.split(".")[:3])
32 |
33 | theirs = VersionInfo.parse(theirs)
34 | mine = VersionInfo.parse(mine)
35 |
36 | if theirs.major > mine.major:
37 | print(
38 | f"{what} was trained with {library} {theirs}, yours is {mine}. "
39 | f"Bad things will probably happen unless you upgrade {library} to {theirs.major}.x."
40 | )
41 |
42 | elif theirs.major < mine.major:
43 | print(
44 | f"{what} was trained with {library} {theirs}, yours is {mine}. "
45 | f"Bad things might happen unless you revert {library} to {theirs.major}.x."
46 | )
47 |
48 | elif theirs.minor > mine.minor:
49 | print(
50 | f"{what} was trained with {library} {theirs}, yours is {mine}. "
51 | f"This should be OK but you might want to upgrade {library}."
52 | )
53 |
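54 |
55 | # Minimal usage sketch (hypothetical): a pipeline trained with a newer major
56 | # version than the installed one triggers the upgrade warning.
57 | if __name__ == "__main__":
58 |     check_version("pyannote.audio", theirs="3.1.0", mine="2.0.1", what="Pipeline")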
--------------------------------------------------------------------------------
/pyannote-audio/notebook/sharing.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from pyannote.database import get_protocol, FileFinder\n",
10 | "protocol = get_protocol('Debug.SpeakerDiarization.Debug', \n",
11 | " preprocessors={\"audio\": FileFinder()})"
12 | ]
13 | },
14 | {
15 | "cell_type": "markdown",
16 | "metadata": {},
17 | "source": [
18 | "## Train a model"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": null,
24 | "metadata": {},
25 | "outputs": [],
26 | "source": [
27 | "from pyannote.audio.tasks import VoiceActivityDetection\n",
28 | "from pyannote.audio.models.segmentation.debug import SimpleSegmentationModel\n",
29 | "import pytorch_lightning as pl\n",
30 | "\n",
31 | "vad = VoiceActivityDetection(protocol, duration=2., batch_size=32, num_workers=4)\n",
32 | "model = SimpleSegmentationModel(task=vad)\n",
33 | "trainer = pl.Trainer(max_epochs=1, default_root_dir='sharing/')\n",
34 | "_ = trainer.fit(model)"
35 | ]
36 | },
37 | {
38 | "cell_type": "markdown",
39 | "metadata": {},
40 | "source": [
41 | "## Load a model without knowing its class"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": null,
47 | "metadata": {},
48 | "outputs": [],
49 | "source": [
50 | "from pyannote.audio import Model\n",
51 | "model = Model.from_pretrained('sharing/lightning_logs/version_0/checkpoints/epoch=0-step=3.ckpt')\n",
52 | "assert isinstance(model, SimpleSegmentationModel)\n",
53 | "\n",
54 | "# checkpoint should work with a URL as well (it relies on pl_load)"
55 | ]
56 | }
57 | ],
58 | "metadata": {
59 | "kernelspec": {
60 | "display_name": "Python 3",
61 | "language": "python",
62 | "name": "python3"
63 | },
64 | "language_info": {
65 | "codemirror_mode": {
66 | "name": "ipython",
67 | "version": 3
68 | },
69 | "file_extension": ".py",
70 | "mimetype": "text/x-python",
71 | "name": "python",
72 | "nbconvert_exporter": "python",
73 | "pygments_lexer": "ipython3",
74 | "version": "3.8.5"
75 | }
76 | },
77 | "nbformat": 4,
78 | "nbformat_minor": 4
79 | }
80 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/utils/random.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2020 CNRS
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 |
24 | import os
25 | import zlib
26 | from random import Random
27 |
28 | import torch
29 |
30 |
31 | def create_rng_for_worker(model) -> Random:
32 | """Create worker-specific random number generator
33 |
34 | This makes sure that
35 | 1. training samples generation is reproducible
36 | 2. every (worker, epoch) uses a different seed
37 |
38 | Parameters
39 | ----------
40 |     model : Model
41 |         Model being trained; provides local_rank, global_rank and current_epoch.
42 | """
43 |
44 | # create random number generator
45 | rng = Random()
46 |
47 | global_seed = os.environ.get("PL_GLOBAL_SEED", "unset")
48 | worker_info = torch.utils.data.get_worker_info()
49 |
50 | if worker_info is None:
51 | worker_id = None
52 | else:
53 | worker_id = worker_info.id
54 |
55 | seed_tuple = (
56 | global_seed,
57 | worker_id,
58 | model.local_rank,
59 | model.global_rank,
60 | model.current_epoch,
61 | )
62 | # use adler32 because python's `hash` is not deterministic.
63 | seed = zlib.adler32(str(seed_tuple).encode())
64 | rng.seed(seed)
65 |
66 | return rng
67 |
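68 |
69 | # Minimal usage sketch (hypothetical): any object exposing local_rank,
70 | # global_rank and current_epoch works; in practice the Lightning module
71 | # provides these. Identical (seed, worker, rank, epoch) tuples yield
72 | # identical random streams.
73 | if __name__ == "__main__":
74 |     from types import SimpleNamespace
75 |     fake_model = SimpleNamespace(local_rank=0, global_rank=0, current_epoch=0)
76 |     rng_a = create_rng_for_worker(fake_model)
77 |     rng_b = create_rng_for_worker(fake_model)
78 |     assert rng_a.random() == rng_b.random()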
--------------------------------------------------------------------------------
/recipes/diar_ssl_pruning/conf/s80_base.toml:
--------------------------------------------------------------------------------
1 | [meta]
2 | save_dir = "exp"
3 | seed = 3407
4 |
5 | [finetune]
6 | finetune = false
7 |
8 | [trainer]
9 | path = "trainer_distill_prune.Trainer"
10 | [trainer.args]
11 | debug = false
12 | max_epochs = 30
13 | gradient_percentile = 90
14 | gradient_history_size = 1000
15 | save_max_score = false
16 | save_ckpt_interval = 1
17 | max_patience = 5
18 | max_num_checkpoints = 100
19 | gradient_accumulation_steps = 1
20 | validation_interval = 1
21 | freeze_wavlm = false
22 | lr_decay = false
23 | use_one_cycle_lr = false
24 | # prune
25 | use_reg = true
26 | target_sparsity = 0.8
27 | pre_train_epochs = 0
28 | sparsity_warmup_epochs = 5
29 |
30 | [optimizer]
31 | path = "torch.optim.AdamW"
32 | [optimizer.args]
33 | distill_lr = 2e-4
34 | reg_lr = 2e-2
35 |
36 | [model]
37 | path = "diarizen.models.pruning.model_distill_prune.Model"
38 | [model.args]
39 | teacher_ckpt = "/YOUR_PATH/wavlm-base-plus-finetuned.bin"
40 | student_ckpt = "/YOUR_PATH/wavlm-base-plus-finetuned.bin"
41 | pruning_units = "conv,head,interm"
42 | distill_layers = "0,4,8,12"
43 |
44 | [distill_loss]
45 | path = "diarizen.models.pruning.utils.DistillLoss"
46 | [distill_loss.args]
47 | l2_weight = 0
48 | l1_weight = 1
49 | cos_weight = 1
50 | cos_type = "raw"
51 |
52 | [train_dataset]
53 | path = "dataset.DiarizationDataset"
54 | [train_dataset.args]
55 | scp_file = "data/AMI_AliMeeting_AISHELL4/train/wav.scp"
56 | rttm_file = "data/AMI_AliMeeting_AISHELL4/train/rttm"
57 | uem_file = "data/AMI_AliMeeting_AISHELL4/train/all.uem"
58 | chunk_size = 8
59 | chunk_shift = 6
60 | sample_rate = 16000
61 | model_num_frames = 399
62 | model_rf_duration = 0.025
63 | model_rf_step = 0.02
64 |
65 | [train_dataset.dataloader]
66 | batch_size = 16
67 | num_workers = 1
68 | drop_last = true
69 | pin_memory = true
70 |
71 | [validate_dataset]
72 | path = "dataset.DiarizationDataset"
73 | [validate_dataset.args]
74 | scp_file = "data/AMI_AliMeeting_AISHELL4/dev/wav.scp"
75 | rttm_file = "data/AMI_AliMeeting_AISHELL4/dev/rttm"
76 | uem_file = "data/AMI_AliMeeting_AISHELL4/dev/all.uem"
77 | chunk_size = 8
78 | chunk_shift = 8
79 | sample_rate = 16000
80 | model_num_frames = 399
81 | model_rf_duration = 0.025
82 | model_rf_step = 0.02
83 |
84 | [validate_dataset.dataloader]
85 | batch_size = 16
86 | num_workers = 1
87 | drop_last = true
88 | pin_memory = true
89 |
90 |
--------------------------------------------------------------------------------
/recipes/diar_ssl_pruning/conf/s80_large.toml:
--------------------------------------------------------------------------------
1 | [meta]
2 | save_dir = "exp"
3 | seed = 3407
4 |
5 | [finetune]
6 | finetune = false
7 |
8 | [trainer]
9 | path = "trainer_distill_prune.Trainer"
10 | [trainer.args]
11 | debug = false
12 | max_epochs = 30
13 | gradient_percentile = 90
14 | gradient_history_size = 1000
15 | save_max_score = false
16 | save_ckpt_interval = 1
17 | max_patience = 5
18 | max_num_checkpoints = 100
19 | gradient_accumulation_steps = 1
20 | validation_interval = 1
21 | freeze_wavlm = false
22 | lr_decay = false
23 | use_one_cycle_lr = false
24 | # prune
25 | use_reg = true
26 | target_sparsity = 0.8
27 | pre_train_epochs = 0
28 | sparsity_warmup_epochs = 5
29 |
30 | [optimizer]
31 | path = "torch.optim.AdamW"
32 | [optimizer.args]
33 | distill_lr = 2e-4
34 | reg_lr = 2e-2
35 |
36 | [model]
37 | path = "diarizen.models.pruning.model_distill_prune.Model"
38 | [model.args]
39 | teacher_ckpt = "/YOUR_PATH/wavlm-large-finetuned.bin"
40 | student_ckpt = "/YOUR_PATH/wavlm-large-finetuned.bin"
41 | pruning_units = "conv,head,interm"
42 | distill_layers = "0,8,16,24"
43 |
44 | [distill_loss]
45 | path = "diarizen.models.pruning.utils.DistillLoss"
46 | [distill_loss.args]
47 | l2_weight = 0
48 | l1_weight = 1
49 | cos_weight = 1
50 | cos_type = "raw"
51 |
52 | [train_dataset]
53 | path = "dataset.DiarizationDataset"
54 | [train_dataset.args]
55 | scp_file = "data/AMI_AliMeeting_AISHELL4/train/wav.scp"
56 | rttm_file = "data/AMI_AliMeeting_AISHELL4/train/rttm"
57 | uem_file = "data/AMI_AliMeeting_AISHELL4/train/all.uem"
58 | chunk_size = 8
59 | chunk_shift = 6
60 | sample_rate = 16000
61 | model_num_frames = 399
62 | model_rf_duration = 0.025
63 | model_rf_step = 0.02
64 |
65 | [train_dataset.dataloader]
66 | batch_size = 8
67 | num_workers = 1
68 | drop_last = true
69 | pin_memory = true
70 |
71 | [validate_dataset]
72 | path = "dataset.DiarizationDataset"
73 | [validate_dataset.args]
74 | scp_file = "data/AMI_AliMeeting_AISHELL4/dev/wav.scp"
75 | rttm_file = "data/AMI_AliMeeting_AISHELL4/dev/rttm"
76 | uem_file = "data/AMI_AliMeeting_AISHELL4/dev/all.uem"
77 | chunk_size = 8
78 | chunk_shift = 8
79 | sample_rate = 16000
80 | model_num_frames = 399
81 | model_rf_duration = 0.025
82 | model_rf_step = 0.02
83 |
84 | [validate_dataset.dataloader]
85 | batch_size = 16
86 | num_workers = 1
87 | drop_last = true
88 | pin_memory = true
89 |
90 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/torchmetrics/classification/equal_error_rate.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2023- CNRS
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 |
24 | from typing import Optional
25 |
26 | import torch
27 | from pyannote.metrics.binary_classification import det_curve
28 | from torchmetrics import Metric
29 | from torchmetrics.utilities.data import dim_zero_cat
30 |
31 |
32 | class EqualErrorRate(Metric):
33 |
34 | is_differentiable: Optional[bool] = False
35 | higher_is_better: Optional[bool] = False
36 | full_state_update: bool = True
37 |
38 | def __init__(self, distances: bool = True, compute_on_cpu: bool = True, **kwargs):
39 | super().__init__(compute_on_cpu=compute_on_cpu, **kwargs)
40 | self.distances = distances
41 | self.add_state("scores", default=[], dist_reduce_fx="cat")
42 | self.add_state("y_true", default=[], dist_reduce_fx="cat")
43 |
44 | def update(self, scores: torch.Tensor, y_true: torch.Tensor) -> None:
45 | self.scores.append(scores)
46 | self.y_true.append(y_true)
47 |
48 | def compute(self) -> torch.Tensor:
49 | scores = dim_zero_cat(self.scores)
50 | y_true = dim_zero_cat(self.y_true)
51 | _, _, _, eer = det_curve(y_true.cpu(), scores.cpu(), distances=self.distances)
52 | return torch.tensor(eer)
53 |
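54 |
55 | # Minimal usage sketch (hypothetical): accumulate distance scores with binary
56 | # labels across batches, then compute the EER once.
57 | if __name__ == "__main__":
58 |     metric = EqualErrorRate(distances=True)
59 |     scores = torch.tensor([0.1, 0.2, 0.8, 0.9])  # smaller distance = same speaker
60 |     y_true = torch.tensor([1, 1, 0, 0])
61 |     metric.update(scores, y_true)
62 |     print(metric.compute())  # perfectly separated scores -> EER of 0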
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/utils/multi_task.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2023- CNRS
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 |
24 | from typing import Any, Callable, Tuple, Union
25 |
26 | from pyannote.audio.core.model import Specifications
27 |
28 |
29 | def map_with_specifications(
30 | specifications: Union[Specifications, Tuple[Specifications]],
31 | func: Callable,
32 | *iterables,
33 | ) -> Union[Any, Tuple[Any]]:
34 | """Compute the function using arguments from each of the iterables
35 |
36 | Returns a tuple if provided `specifications` is a tuple,
37 | otherwise returns the function return value.
38 |
39 | Parameters
40 | ----------
41 | specifications : (tuple of) Specifications
42 | Specifications or tuple of specifications
43 | func : callable
44 | Function called for each specification with
45 | `func(*iterables[i], specifications=specifications[i])`
46 | *iterables :
47 | List of iterables with same length as `specifications`.
48 |
49 | Returns
50 | -------
51 | output : (tuple of) `func` return value(s)
52 | """
53 |
54 | if isinstance(specifications, Specifications):
55 | return func(*iterables, specifications=specifications)
56 |
57 | return tuple(
58 | func(*i, specifications=s) for s, *i in zip(specifications, *iterables)
59 | )
60 |
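61 |
62 | # Minimal usage sketch (hypothetical): with a single Specifications the
63 | # function is applied once; with a tuple it is applied element-wise and a
64 | # tuple is returned.
65 | if __name__ == "__main__":
66 |     from pyannote.audio.core.task import Problem, Resolution
67 |
68 |     spec = Specifications(
69 |         problem=Problem.REPRESENTATION, resolution=Resolution.CHUNK, duration=5.0
70 |     )
71 |
72 |     def add_one(x, specifications=None):
73 |         return x + 1
74 |
75 |     assert map_with_specifications(spec, add_one, 1) == 2
76 |     assert map_with_specifications((spec, spec), add_one, [1, 2]) == (2, 3)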
--------------------------------------------------------------------------------
/pyannote-audio/notebook/freeze.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from pyannote.database import get_protocol, FileFinder\n",
10 | "protocol = get_protocol('Debug.SpeakerDiarization.Debug', \n",
11 | " preprocessors={\"audio\": FileFinder()})"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": null,
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "from pyannote.audio.tasks import VoiceActivityDetection\n",
21 | "from pyannote.audio.models.segmentation.debug import SimpleSegmentationModel\n",
22 | "import pytorch_lightning as pl"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": null,
28 | "metadata": {},
29 | "outputs": [],
30 | "source": [
31 | "vad = VoiceActivityDetection(protocol, duration=2., batch_size=16, num_workers=4)\n",
32 | "model = SimpleSegmentationModel(task=vad)\n",
33 | "trainer = pl.Trainer(max_epochs=1)\n",
34 | "_ = trainer.fit(model)"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": null,
40 | "metadata": {},
41 | "outputs": [],
42 | "source": [
43 | "summary = model.summarize('full')"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": null,
49 | "metadata": {},
50 | "outputs": [],
51 | "source": [
52 | "model.freeze_up_to('lstm')"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": null,
58 | "metadata": {},
59 | "outputs": [],
60 | "source": [
61 | "model.unfreeze_up_to('mfcc.MelSpectrogram.spectrogram')"
62 | ]
63 | },
64 | {
65 | "cell_type": "code",
66 | "execution_count": null,
67 | "metadata": {},
68 | "outputs": [],
69 | "source": [
70 | "model.freeze_by_name(['lstm', 'activation'])"
71 | ]
72 | }
73 | ],
74 | "metadata": {
75 | "kernelspec": {
76 | "display_name": "Python 3",
77 | "language": "python",
78 | "name": "python3"
79 | },
80 | "language_info": {
81 | "codemirror_mode": {
82 | "name": "ipython",
83 | "version": 3
84 | },
85 | "file_extension": ".py",
86 | "mimetype": "text/x-python",
87 | "name": "python",
88 | "nbconvert_exporter": "python",
89 | "pygments_lexer": "ipython3",
90 | "version": "3.8.5"
91 | }
92 | },
93 | "nbformat": 4,
94 | "nbformat_minor": 4
95 | }
96 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/cli/lr_schedulers/CosineAnnealingWarmRestarts.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2022 CNRS
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | from typing import Optional
24 |
25 | from torch.optim import Optimizer
26 | from torch.optim.lr_scheduler import (
27 | CosineAnnealingWarmRestarts as _CosineAnnealingWarmRestarts,
28 | )
29 |
30 |
31 | def CosineAnnealingWarmRestarts(
32 | optimizer: Optimizer,
33 | min_lr: float = 1e-8,
34 | max_lr: float = 1e-3,
35 | patience: int = 1,
36 | num_batches_per_epoch: Optional[int] = None,
37 | **kwargs,
38 | ):
39 | """Wrapper around CosineAnnealingWarmRestarts
40 |
41 | Parameters
42 | ----------
43 | optimizer : Optimizer
44 | Optimizer
45 | min_lr : float, optional
46 | Defaults to 1e-8.
47 | max_lr : float, optional
48 | Defaults to 1e-3
49 | patience : int, optional
50 | Number of epochs per cycle. Defaults to 1.
51 | num_batches_per_epoch : int, optional
52 | Number of batches per epoch.
53 | """
54 |
55 | # initialize optimizer lr to max_lr
56 | for g in optimizer.param_groups:
57 | g["lr"] = max_lr
58 |
59 | num_steps = patience * num_batches_per_epoch
60 |
61 | return {
62 | "scheduler": _CosineAnnealingWarmRestarts(
63 | optimizer, num_steps, eta_min=min_lr, T_mult=2
64 | ),
65 | "interval": "step",
66 | }
67 |
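68 |
69 | # Minimal usage sketch (hypothetical): with patience=1 and 100 batches per
70 | # epoch, the first cosine cycle lasts 100 steps and each subsequent cycle
71 | # doubles in length (T_mult=2).
72 | if __name__ == "__main__":
73 |     import torch
74 |
75 |     sgd = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1.0)
76 |     config = CosineAnnealingWarmRestarts(sgd, max_lr=1e-3, patience=1, num_batches_per_epoch=100)
77 |     assert config["interval"] == "step"
78 |     assert abs(sgd.param_groups[0]["lr"] - 1e-3) < 1e-9  # lr initialized to max_lr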
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/cli/lr_schedulers/CyclicLR.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 CNRS
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | from typing import Optional
24 |
25 | from torch.optim import Optimizer
26 | from torch.optim.lr_scheduler import CyclicLR as _CyclicLR
27 |
28 |
29 | def CyclicLR(
30 | optimizer: Optimizer,
31 | min_lr: float = 1e-8,
32 | max_lr: float = 1e-3,
33 | mode: str = "triangular2",
34 | patience: int = 50,
35 | num_batches_per_epoch: Optional[int] = None,
36 | **kwargs,
37 | ):
38 | """Wrapper around CyclicLR learning rate scheduler
39 |
40 | Parameters
41 | ----------
42 | optimizer : Optimizer
43 | Optimizer
44 | min_lr : float, optional
45 | Defaults to 1e-8.
46 | max_lr : float, optional
47 | Defaults to 1e-3
48 | patience : int, optional
49 | Number of epochs per cycle. Defaults to 50.
50 | num_batches_per_epoch : int, optional
51 | Number of batches per epoch.
52 | mode : {"triangular", "triangular2"}, optional
53 | Defaults to "triangular2".
54 | """
55 |
56 | step_size_up = int(0.5 * patience * num_batches_per_epoch)
57 |
58 | return {
59 | "scheduler": _CyclicLR(
60 | optimizer,
61 | base_lr=min_lr,
62 | max_lr=max_lr,
63 | step_size_up=step_size_up,
64 | mode=mode,
65 | cycle_momentum=False,
66 | ),
67 | "interval": "step",
68 | }
69 |
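70 |
71 | # Minimal usage sketch (hypothetical): with patience=50 epochs and 20 batches
72 | # per epoch, the LR ramps up for 0.5 * 50 * 20 = 500 steps and back down for
73 | # another 500, halving its peak every cycle ("triangular2" mode).
74 | if __name__ == "__main__":
75 |     import torch
76 |
77 |     sgd = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
78 |     config = CyclicLR(sgd, patience=50, num_batches_per_epoch=20)
79 |     assert config["scheduler"].total_size == 1000  # 500 up + 500 down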
--------------------------------------------------------------------------------
/recipes/diar_ssl_pruning/README.md:
--------------------------------------------------------------------------------
1 | # Structured Pruning of WavLM
2 | This directory contains scripts for structured pruning of [WavLM](https://arxiv.org/pdf/2110.13900) applied to speaker diarization.
3 |
4 | ## How to run
5 | - Convert the WavLM checkpoint from the HuggingFace format to our custom format: see `convert_wavlm_from_hf.py`.
6 | - Fine-tune WavLM for diarization: see `../diar_ssl/run_stage.sh`.
7 | - Run the pruning training:
8 |   `bash -i run_stage.sh`.
9 |
10 |
11 | ## Results (collar=0s)
12 | | System | Sparsity | Params | MACs | Speedup | AMI | AISHELL-4 | AliMeeting | Macro |
13 | |:----------------|:----------:|:----------:|:---------:|:---------:|:------:|:------------:|:-------------:|:--------:|
14 | | Fbank | - | - | - | - | 19.7 | 12.5 | 21.0 | 17.7 |
15 | | WavLM Base+ | 0% | 94.4M | 6.9G | - | 15.6 | 11.8 | 17.7 | 15.0 |
16 | | | 80% | 18.8M | 1.1G | 4.0× | 15.7 | 12.1 | 17.9 | 15.2 |
17 | | | 90% | 9.4M | 0.6G | 5.7× | 17.2 | 12.1 | 19.2 | 16.1 |
18 | | WavLM Large | 0% | 316.6M | 17.8G | - | 14.8 | 11.3 | 16.3 | 14.1 |
19 | | | 80% | 63.3M | 3.8G | 2.6× | 15.1 | 11.3 | 15.8 | 14.1 |
20 | | | 90% | 30.6M | 1.8G | 3.5× | 15.7 | 11.2 | 17.6 | 14.8 |
21 |
22 | ## Citation
23 | If you find this work helpful, please consider citing:
24 | J. Han, F. Landini, J. Rohdin, A. Silnova, M. Diez, J. Cernocky and L. Burget, [Fine-tune Before Structured Pruning: Towards Compact and Accurate Self-Supervised Models for Speaker Diarization](https://arxiv.org/pdf/2505.24111), in Proc. INTERSPEECH, 2025.
25 | ```
26 | @article{han2025fine,
27 | title={Fine-tune Before Structured Pruning: Towards Compact and Accurate Self-Supervised Models for Speaker Diarization},
28 | author={Han, Jiangyu and Landini, Federico and Rohdin, Johan and Silnova, Anna and Diez, Mireia and Cernocky, Jan and Burget, Lukas},
29 | journal={arXiv preprint arXiv:2505.24111},
30 | year={2025}
31 | }
32 |
33 | @article{han2025efficient,
34 | title={Efficient and Generalizable Speaker Diarization via Structured Pruning of Self-Supervised Models},
35 | author={Han, Jiangyu and P{\'a}lka, Petr and Delcroix, Marc and Landini, Federico and Rohdin, Johan and Cernock{\`y}, Jan and Burget, Luk{\'a}{\v{s}}},
36 | journal={arXiv preprint arXiv:2506.18623},
37 | year={2025}
38 | }
39 | ```
40 |
41 | ## Acknowledgments
42 | We thank the authors of [DPHuBERT](https://github.com/pyf98/DPHuBERT) for open-sourcing their code.
43 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/models/embedding/wespeaker/convert.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2023 CNRS
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | # Script used to convert from WeSpeaker to pyannote.audio
24 |
25 | import sys
26 | from pathlib import Path
27 |
28 | import pytorch_lightning as pl
29 | import torch
30 |
31 | import pyannote.audio.models.embedding.wespeaker as wespeaker
32 | from pyannote.audio import Model
33 | from pyannote.audio.core.task import Problem, Resolution, Specifications
34 |
35 | wespeaker_checkpoint_dir = sys.argv[1] # /path/to/wespeaker_cnceleb-resnet34-LM
36 |
37 | wespeaker_checkpoint = Path(wespeaker_checkpoint_dir) / "wespeaker.pt"
38 |
39 | depth = Path(wespeaker_checkpoint_dir).parts[-1].split("-")[-2][6:] # '34'
40 | Klass = getattr(wespeaker, f"WeSpeakerResNet{depth}") # WeSpeakerResNet34
41 |
42 | duration = 5.0 # whatever
43 | specifications = Specifications(
44 | problem=Problem.REPRESENTATION, resolution=Resolution.CHUNK, duration=duration
45 | )
46 |
47 | state_dict = torch.load(wespeaker_checkpoint, map_location=torch.device("cpu"))
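48 | # drop the classification head ("projection"): it has no counterpart in model.resnet, which is loaded with strict=True below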
48 | state_dict.pop("projection.weight")
49 |
50 | model = Klass()
51 | model.resnet.load_state_dict(state_dict, strict=True)
52 | model.specifications = specifications
53 |
54 | checkpoint = {"state_dict": model.state_dict()}
55 | model.on_save_checkpoint(checkpoint)
56 | checkpoint["pytorch-lightning_version"] = pl.__version__
57 |
58 | pyannote_checkpoint = Path(wespeaker_checkpoint_dir) / "pytorch_model.bin"
59 | torch.save(checkpoint, pyannote_checkpoint)
60 |
61 | model = Model.from_pretrained(pyannote_checkpoint)
62 | print(model)
63 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Custom
2 | exp*
3 | *.log
4 | *sge*
5 | *.err
6 | *.out
7 | *_debug*
8 | .idea/
9 | .vscode/*
10 | !.vscode/launch.json
11 | # generated sphinx files (a trailing "#" is not a comment in .gitignore, so it must go on its own line)
12 | generated/
12 | _build/
13 | *.csv
14 | core.*
15 | tmp/
16 |
17 | # wavlm
18 | WavLM-Base+.pt
19 |
20 | # Ruff
21 | .ruff_cache/
22 |
23 | # Byte-compiled / optimized / DLL files
24 | __pycache__/
25 | *.py[cod]
26 | *$py.class
27 |
28 | # C extensions
29 | *.so
30 |
31 | # Distribution / packaging
32 | .Python
33 | build/
34 | develop-eggs/
35 | dist/
36 | downloads/
37 | eggs/
38 | .eggs/
39 | lib/
40 | lib64/
41 | parts/
42 | sdist/
43 | var/
44 | wheels/
45 | pip-wheel-metadata/
46 | share/python-wheels/
47 | *.egg-info/
48 | .installed.cfg
49 | *.egg
50 | MANIFEST
51 |
52 | # PyInstaller
53 | # Usually these files are written by a python script from a template
54 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
55 | *.manifest
56 | *.spec
57 |
58 | # Installer logs
59 | pip-log.txt
60 | pip-delete-this-directory.txt
61 |
62 | # Unit test / coverage reports
63 | htmlcov/
64 | .tox/
65 | .nox/
66 | .coverage
67 | .coverage.*
68 | .cache
69 | nosetests.xml
70 | coverage.xml
71 | *.cover
72 | *.py,cover
73 | .hypothesis/
74 | .pytest_cache/
75 |
76 | # Translations
77 | *.mo
78 | *.pot
79 |
80 | # Django stuff:
81 | *.log
82 | local_settings.py
83 | db.sqlite3
84 | db.sqlite3-journal
85 |
86 | # Flask stuff:
87 | instance/
88 | .webassets-cache
89 |
90 | # Scrapy stuff:
91 | .scrapy
92 |
93 | # Sphinx documentation
94 | docs/_build/
95 |
96 | # PyBuilder
97 | target/
98 |
99 | # Jupyter Notebook
100 | .ipynb_checkpoints
101 |
102 | # IPython
103 | profile_default/
104 | ipython_config.py
105 |
106 | # pyenv
107 | .python-version
108 |
109 | # pipenv
110 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
111 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
112 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
113 | # install all needed dependencies.
114 | #Pipfile.lock
115 |
116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
117 | __pypackages__/
118 |
119 | # Celery stuff
120 | celerybeat-schedule
121 | celerybeat.pid
122 |
123 | # SageMath parsed files
124 | *.sage.py
125 |
126 | # Environments
127 | .env
128 | .venv
129 | env/
130 | venv/
131 | ENV/
132 | env.bak/
133 | venv.bak/
134 |
135 | # Spyder project settings
136 | .spyderproject
137 | .spyproject
138 |
139 | # Rope project settings
140 | .ropeproject
141 |
142 | # mkdocs documentation
143 | /site
144 |
145 | # mypy
146 | .mypy_cache/
147 | .dmypy.json
148 | dmypy.json
149 |
150 | # Pyre type checker
151 | .pyre/
152 |
--------------------------------------------------------------------------------
/pyannote-audio/tests/io_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchaudio
3 | from pyannote.core import Segment
4 | from torch import Tensor
5 |
6 | from pyannote.audio.core.io import Audio
7 |
8 |
9 | def test_audio_resample():
10 | "Audio is correctly resampled when it isn't the correct sample rate"
11 | test_file = "tests/data/dev00.wav"
12 | info = torchaudio.info(test_file)
13 | old_sr = info.sample_rate
14 | loader = Audio(sample_rate=old_sr // 2, mono="downmix")
15 | wav, sr = loader(test_file)
16 | assert isinstance(wav, Tensor)
17 | assert sr == old_sr // 2
18 |
19 |
20 | def test_basic_load_with_defaults():
21 | test_file = "tests/data/dev00.wav"
22 | loader = Audio(mono="downmix")
23 | wav, sr = loader(test_file)
24 | assert isinstance(wav, Tensor)
25 |
26 |
27 | def test_correct_audio_channel():
28 | "When we specify an audio channel, it is chosen correctly"
29 | waveform = torch.rand(2, 16000 * 2)
30 | loader = Audio(mono="downmix")
31 | wav, sr = loader({"waveform": waveform, "sample_rate": 16000, "channel": 1})
32 | assert torch.equal(wav, waveform[1:2])
33 | assert sr == 16000
34 |
35 |
36 | def test_can_load_with_waveform():
37 | "We can load a raw waveform"
38 | waveform = torch.rand(2, 16000 * 2)
39 | loader = Audio(mono="downmix")
40 | wav, sr = loader({"waveform": waveform, "sample_rate": 16000})
41 | assert isinstance(wav, Tensor)
42 | assert sr == 16000
43 |
44 |
45 | def test_can_crop():
46 | "Cropping works when we give a Segment"
47 | test_file = "tests/data/dev00.wav"
48 | loader = Audio(mono="downmix")
49 | segment = Segment(0.2, 0.7)
50 | wav, sr = loader.crop(test_file, segment)
51 | assert wav.shape[1] / sr == 0.5
52 |
53 |
54 | def test_can_crop_waveform():
55 | "Cropping works on raw waveforms"
56 | waveform = torch.rand(1, 16000 * 2)
57 | loader = Audio(mono="downmix")
58 | segment = Segment(0.2, 0.7)
59 | wav, sr = loader.crop({"waveform": waveform, "sample_rate": 16000}, segment)
60 | assert isinstance(wav, Tensor)
61 | assert sr == 16000
62 |
63 |
64 | # File Like Object Tests
65 | def test_can_load_from_file_like():
66 | "Load entire wav of file like"
67 | loader = Audio(mono="downmix")
68 |
69 | with open("tests/data/dev00.wav", "rb") as f:
70 | wav, sr = loader(f)
71 |
72 | assert isinstance(wav, Tensor)
73 | assert sr == 16000
74 |
75 |
76 | def test_can_crop_from_file_like():
77 | "Load cropped sections from file like objects"
78 | loader = Audio(mono="downmix")
79 |
80 | with open("tests/data/dev00.wav", "rb") as f:
81 | segment = Segment(0.2, 0.7)
82 | wav, sr = loader.crop(f, segment)
83 |
84 | assert isinstance(wav, Tensor)
85 | assert sr == 16000
86 | assert wav.shape[1] == 0.5 * 16000
87 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/cli/lr_schedulers/ReduceLROnPlateau.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 CNRS
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 |
24 | from typing import Optional, Text
25 |
26 | from torch.optim import Optimizer
27 | from torch.optim.lr_scheduler import ReduceLROnPlateau as _ReduceLROnPlateau
28 |
29 |
30 | def ReduceLROnPlateau(
31 | optimizer: Optimizer,
32 | monitor: Optional[Text] = None,
33 | direction: Optional[Text] = "min",
34 | min_lr: float = 1e-8,
35 | max_lr: float = 1e-3,
36 | factor: float = 0.5,
37 | patience: int = 50,
38 | **kwargs,
39 | ):
40 | """Wrapper around ReduceLROnPlateau learning rate scheduler
41 |
42 | Parameters
43 | ----------
44 |     optimizer : Optimizer
45 |         Optimizer.
46 |     monitor : str, optional
47 |         Quantity to monitor.
48 |     direction : {"min", "max"}, optional
49 |         "min" (resp. "max") means smaller (resp. larger) is better.
50 |     min_lr : float, optional
51 |         Defaults to 1e-8.
52 |     max_lr : float, optional
53 |         Defaults to 1e-3.
54 |     factor : float, optional
55 |         Multiplicative factor applied to the learning rate. Defaults to 0.5.
56 |     patience : int, optional
57 |         Wait this many epochs with no improvement before reducing the learning rate.
58 |         Defaults to 50.
59 | """
60 |
61 | # initialize optimizer lr to max_lr
62 | for g in optimizer.param_groups:
63 | g["lr"] = max_lr
64 |
65 | return {
66 | "scheduler": _ReduceLROnPlateau(
67 | optimizer,
68 | mode=direction,
69 | factor=factor,
70 | patience=patience,
71 | threshold=0.0001,
72 | threshold_mode="rel",
73 | cooldown=0,
74 | min_lr=min_lr,
75 | eps=1e-08,
76 | verbose=False,
77 | ),
78 | "interval": "epoch",
79 | "monitor": monitor,
80 | "strict": True,
81 | }
82 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | # ----------------- Build System -----------------
2 | [build-system]
3 | requires = ["flit_core >=3.2,<4"]
4 | build-backend = "flit_core.buildapi"
5 |
6 | # ----------------- Metadata -----------------
7 | [project]
8 | name = "diarizen"
9 | description = """DiariZen is a speaker diarization toolkit based on AudioZen .
10 | The AudioZen is mainly maintained by Xiang Hao (Hong Kong Polytechnic University).
11 | The DiariZen is mainly maintained by Jiangyu Han (Brno University of Technology)"""
12 | authors = [
13 | { name = "Xiang Hao", email = "haoxiangsnr@gmail.com" },
14 | { name = "Jiangyu Han", email = "ihan@fit.vut.cz"}
15 | ]
16 | readme = "README.md"
17 | requires-python = ">=3.10"
18 | version = "0.0.1"
19 | classifiers = [
20 | "Programming Language :: Python :: 3.10",
21 | "Development Status :: 2 - Pre-Alpha",
22 | "License :: OSI Approved :: MIT License",
23 | "Environment :: GPU :: NVIDIA CUDA",
24 | "Operating System :: OS Independent",
25 | ]
26 | keywords = [
27 | "multiple purposes",
28 | "speaker localization/tracking",
29 | "dereverberation",
30 | "enhancement",
31 | "separation",
32 | "recognition",
33 | "diarization"
34 | ]
35 | [project.optional-dependencies]
36 | test = ["pytest", "pytest-cov"]
37 | docs = ["importlib_metadata", "sphinx-autoapi", "sphinx-rtd-theme", "myst-parser", "myst-nb"]
38 | build = ["flit", "python-semantic-release", "sphinx-autobuild"]
39 | [project.urls]
40 | Source = "https://github.com/BUTSpeechFIT/DiariZen"
41 |
42 | # ----------------- Tools Configuration -----------------
43 | [tool.semantic_release]
44 | version_toml = "pyproject.toml:project.version" # version location
45 | branch = "main" # branch to make releases
46 | changelog_file = "CHANGELOG.md" # changelog file
47 | build_command = "flit build" # build dists
48 | upload_to_release = true # auto-create GitHub release
49 | upload_to_repository = false # don't auto-upload to PyPI
50 | remove_dist = false # don't remove dists
51 | patch_without_tag = false # don't create patch releases by default
52 | commit_author = "Jiangyu Han <ihan@fit.vut.cz>"
53 | commit_subject = "Release {version}"
54 | commit_message = "" # commit message
55 |
56 | [tool.ruff]
57 | # Never enforce `E501` (line length violations).
58 | ignore = ["C901", "E501", "E741", "F402", "F823"]
59 | select = ["C", "E", "F", "I", "W"]
60 | line-length = 119
61 |
62 | # Ignore import violations in all `__init__.py` files.
63 | [tool.ruff.per-file-ignores]
64 | "__init__.py" = ["E402", "F401", "F403", "F811"]
65 |
66 | [tool.ruff.isort]
67 | lines-after-imports = 2
68 | known-first-party = ["diarizen"]
69 |
70 | [tool.ruff.format]
71 | # Like Black, use double quotes for strings.
72 | quote-style = "double"
73 | # Like Black, indent with spaces, rather than tabs.
74 | indent-style = "space"
75 | # Like Black, respect magic trailing commas.
76 | skip-magic-trailing-comma = false
77 | # Like Black, automatically detect the appropriate line ending.
78 | line-ending = "auto"
79 | docstring-code-format = true
80 | docstring-code-line-length = 119
81 |
82 | [tool.ruff.lint.pydocstyle]
83 | convention = "google"
84 |
--------------------------------------------------------------------------------
/diarizen/logger.py:
--------------------------------------------------------------------------------
1 | # Licensed under the MIT license.
2 | # Copy from https://github.com/haoxiangsnr/spiking-fullsubnet/blob/main/audiozen/logger.py
3 | # Copyright 2024 Hong Kong Polytechnic University (author: Xiang Hao, haoxiangsnr@gmail.com)
4 |
5 | import logging
6 | import os
7 | import time
8 | from pathlib import Path
9 |
10 | import toml
11 | from torch.utils.tensorboard import SummaryWriter # type: ignore
12 |
13 |
14 | class TensorboardLogger(SummaryWriter):
15 | def __init__(self, log_dir: str = "") -> None:
16 | super().__init__(log_dir=log_dir, max_queue=5, flush_secs=30)
17 |
18 | def log_config(self, config: dict) -> None:
19 | self.add_text(
20 | tag="Configuration",
21 | text_string=f" \n{toml.dumps(config)} \n",
22 | global_step=1,
23 | )
24 |
25 |
26 | def init_logging_logger(config):
27 | """Initialize logging logger with handlers.
28 |
29 | Args:
30 | log_fpath: Path to save log file.
31 |
32 | Examples:
33 | >>> # Call this function at the beginning of main file.
34 | >>> init_logger(log_fpath="log_path")
35 | >>> # Use this logger in other modules.
36 | >>> import logging
37 | >>> logger = logging.getLogger(__name__)
38 | >>> logger.info("info message")
39 | """
40 | # Parse log_fpath
41 | log_dir: Path = Path(config["meta"]["save_dir"]).expanduser().absolute() / config["meta"]["exp_id"]
42 | log_dir.mkdir(parents=True, exist_ok=True)
43 |
44 | # disable logging for libraries that use standard logging module
45 | logging.getLogger("matplotlib").setLevel(logging.WARNING)
46 | logging.getLogger("numba").setLevel(logging.ERROR)
47 |
48 | # Create logger
49 | logger = logging.getLogger()
50 |
51 |     # Set the root logger to the lowest level and control logging via each handler's level
52 | logger.setLevel(logging.DEBUG)
53 |
54 | # Get log level from environment variable
55 | log_level = os.environ.get("LOG_LEVEL", "INFO").upper()
56 |
57 | # Create a console handler and set level to info
58 | console_handler = logging.StreamHandler()
59 | console_handler.setLevel(level=log_level)
60 |
61 | # Create a file handler and set the logger level to debug
62 | time_now = time.strftime("%Y_%m_%d--%H_%M_%S")
63 | file_handler = logging.FileHandler(str(log_dir / f"{config['meta']['exp_id']}_{time_now}.log"))
64 | file_handler.setLevel(level=log_level)
65 |
66 |     # Create formatters (the file handler logs more detail)
67 | console_formatter = logging.Formatter(
68 | "%(asctime)s: %(message)s",
69 | datefmt="%m-%d %H:%M:%S",
70 | )
71 | file_formatter = logging.Formatter(
72 | "%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d]: %(message)s",
73 | datefmt="%m-%d %H:%M:%S",
74 | )
75 |
76 |     # Attach formatters to the handlers
77 | console_handler.setFormatter(console_formatter)
78 | file_handler.setFormatter(file_formatter)
79 |
80 |     # Attach the handlers to the root logger
81 | logger.addHandler(console_handler)
82 | logger.addHandler(file_handler)
83 |
84 | logger.info(f"Initialized logger with log file in {log_dir.as_posix()}.")
85 |
--------------------------------------------------------------------------------
/pyannote-audio/tests/inference_test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | import pytorch_lightning as pl
4 | from pyannote.core import SlidingWindowFeature
5 | from pyannote.database import FileFinder, get_protocol
6 |
7 | from pyannote.audio import Inference, Model
8 | from pyannote.audio.core.task import Resolution
9 | from pyannote.audio.models.segmentation.debug import SimpleSegmentationModel
10 | from pyannote.audio.tasks import VoiceActivityDetection
11 |
12 | HF_SAMPLE_MODEL_ID = "pyannote/ci-segmentation"
13 |
14 |
15 | def test_hf_download_inference():
16 | inference = Inference(HF_SAMPLE_MODEL_ID, device="cpu")
17 | assert isinstance(inference, Inference)
18 |
19 |
20 | def test_hf_download_model():
21 | model = Model.from_pretrained(HF_SAMPLE_MODEL_ID)
22 | assert isinstance(model, Model)
23 |
24 |
25 | @pytest.fixture()
26 | def trained():
27 | protocol = get_protocol(
28 | "Debug.SpeakerDiarization.Debug", preprocessors={"audio": FileFinder()}
29 | )
30 | vad = VoiceActivityDetection(protocol, duration=2.0, batch_size=16, num_workers=4)
31 | model = SimpleSegmentationModel(task=vad)
32 | trainer = pl.Trainer(fast_dev_run=True, accelerator="cpu")
33 | trainer.fit(model)
34 | return protocol, model
35 |
36 |
37 | @pytest.fixture()
38 | def pretrained_model():
39 | return Model.from_pretrained(HF_SAMPLE_MODEL_ID)
40 |
41 |
42 | @pytest.fixture()
43 | def dev_file():
44 | protocol = get_protocol(
45 | "Debug.SpeakerDiarization.Debug", preprocessors={"audio": FileFinder()}
46 | )
47 | return next(protocol.development())
48 |
49 |
50 | def test_duration_warning(trained):
51 | protocol, model = trained
52 | with pytest.warns(UserWarning):
53 | duration = model.specifications.duration
54 | new_duration = duration + 1
55 | Inference(model, duration=new_duration, step=0.1, batch_size=128)
56 |
57 |
58 | def test_step_check_warning(trained):
59 | protocol, model = trained
60 | with pytest.raises(ValueError):
61 | duration = model.specifications.duration
62 | Inference(model, step=duration + 1, batch_size=128)
63 |
64 |
65 | def test_invalid_window_fails(trained):
66 | protocol, model = trained
67 | with pytest.raises(ValueError):
68 | Inference(model, window="unknown")
69 |
70 |
71 | def test_invalid_resolution_fails(trained):
72 | protocol, model = trained
73 | with pytest.warns(UserWarning):
74 | model.specifications.resolution = Resolution.FRAME
75 | Inference(model, window="whole", batch_size=128)
76 |
77 |
78 | def test_whole_window_slide(trained):
79 | protocol, model = trained
80 | inference = Inference(model, window="whole", batch_size=128)
81 | dev_file = next(protocol.development())
82 | output = inference(dev_file)
83 | assert isinstance(output, np.ndarray)
84 |
85 |
86 | def test_on_file_path(trained):
87 | protocol, model = trained
88 | inference = Inference(model, batch_size=128)
89 | output = inference("tests/data/dev00.wav")
90 | assert isinstance(output, SlidingWindowFeature)
91 |
92 |
93 | def test_skip_aggregation(pretrained_model, dev_file):
94 | inference = Inference(pretrained_model, skip_aggregation=True)
95 | scores = inference(dev_file)
96 | assert len(scores.data.shape) == 3
97 |
--------------------------------------------------------------------------------
/pyannote-audio/tests/utils/test_permutation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from pyannote.audio.utils.permutation import permutate
5 |
6 |
7 | def test_permutate_torch():
8 |
9 | num_frames, num_speakers = 10, 3
10 |
11 | actual_permutations = [
12 | (0, 1, 2),
13 | (0, 2, 1),
14 | (1, 0, 2),
15 | (1, 2, 0),
16 | (2, 0, 1),
17 | (2, 1, 0),
18 | ]
19 | batch_size = len(actual_permutations)
20 |
21 | y2 = torch.randn((num_frames, num_speakers))
22 | y1 = torch.zeros((batch_size, num_frames, num_speakers))
23 |
24 | for p, permutation in enumerate(actual_permutations):
25 | y1[p] = y2[:, permutation]
26 |
27 | permutated_y2, permutations = permutate(y1, y2)
28 | assert actual_permutations == permutations
29 |
30 | for p, permutation in enumerate(actual_permutations):
31 | np.testing.assert_allclose(permutated_y2[p], y2[:, permutation])
32 |
33 |
34 | def test_permutate_numpy():
35 |
36 | num_frames, num_speakers = 10, 3
37 |
38 | actual_permutations = [
39 | (0, 1, 2),
40 | (0, 2, 1),
41 | (1, 0, 2),
42 | (1, 2, 0),
43 | (2, 0, 1),
44 | (2, 1, 0),
45 | ]
46 | batch_size = len(actual_permutations)
47 |
48 | y2 = np.random.randn(num_frames, num_speakers)
49 | y1 = np.zeros((batch_size, num_frames, num_speakers))
50 |
51 | for p, permutation in enumerate(actual_permutations):
52 | y1[p] = y2[:, permutation]
53 |
54 | permutated_y2, permutations = permutate(y1, y2)
55 | assert actual_permutations == permutations
56 |
57 | for p, permutation in enumerate(actual_permutations):
58 | np.testing.assert_allclose(permutated_y2[p], y2[:, permutation])
59 |
60 |
61 | def test_permutate_less_speakers():
62 |
63 | num_frames = 10
64 |
65 | actual_permutations = [
66 | (0, 1, None),
67 | (0, None, 1),
68 | (1, 0, None),
69 | (1, None, 0),
70 | (None, 0, 1),
71 | (None, 1, 0),
72 | ]
73 | batch_size = len(actual_permutations)
74 |
75 | y2 = np.random.randn(num_frames, 2)
76 | y1 = np.zeros((batch_size, num_frames, 3))
77 |
78 | for p, permutation in enumerate(actual_permutations):
79 | for i, j in enumerate(permutation):
80 | if j is not None:
81 | y1[p, :, i] = y2[:, j]
82 |
83 | permutated_y2, permutations = permutate(y1, y2)
84 |
85 | assert permutations == actual_permutations
86 |
87 |
88 | def test_permutate_more_speakers():
89 |
90 | num_frames = 10
91 |
92 | actual_permutations = [
93 | (0, 1),
94 | (0, 2),
95 | (1, 0),
96 | (1, 2),
97 | (2, 0),
98 | (2, 1),
99 | ]
100 | batch_size = len(actual_permutations)
101 |
102 | y2 = np.random.randn(num_frames, 3)
103 | y1 = np.zeros((batch_size, num_frames, 2))
104 |
105 | for p, permutation in enumerate(actual_permutations):
106 | for i, j in enumerate(permutation):
107 | y1[p, :, i] = y2[:, j]
108 |
109 | permutated_y2, permutations = permutate(y1, y2)
110 |
111 | assert permutations == actual_permutations
112 | np.testing.assert_allclose(permutated_y2, y1)
113 |
--------------------------------------------------------------------------------
/pyannote-audio/notebook/augmentation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "# gett a 5s excerpt of first test file\n",
10 | "from pyannote.database import get_protocol, FileFinder\n",
11 | "protocol = get_protocol('Debug.SpeakerDiarization.Debug', \n",
12 | " preprocessors={\"audio\": FileFinder()})\n",
13 | "\n",
14 | "from pyannote.audio.core.io import Audio\n",
15 | "audio = Audio(sample_rate=16000, mono=\"downmix\")\n",
16 | "file = next(protocol.test())\n",
17 | "\n",
18 | "from pyannote.core import Segment\n",
19 | "waveform, sample_rate = audio.crop(file, Segment(5, 10))\n",
20 | "\n",
21 | "import torch\n",
22 | "waveforms = torch.tensor(waveform)[None, :]"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": null,
28 | "metadata": {},
29 | "outputs": [],
30 | "source": [
31 | "# play the excerpt\n",
32 | "from IPython.display import Audio as Play\n",
33 | "Play(waveforms.squeeze(), rate=sample_rate, normalize=False, autoplay=True)"
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": null,
39 | "metadata": {},
40 | "outputs": [],
41 | "source": [
42 | "# define a model that simply returns the waveform\n",
43 | "from pyannote.audio.core.model import Model\n",
44 | "class Passthrough(Model):\n",
45 | " def forward(self, waveforms):\n",
46 | " return waveforms\n",
47 | " \n",
48 | "identity = Passthrough()"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": null,
54 | "metadata": {},
55 | "outputs": [],
56 | "source": [
57 | "# pass the waveform through this \"identity\" model\n",
58 | "Play(identity(waveforms).squeeze(), rate=sample_rate, normalize=False, autoplay=True)"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": null,
64 | "metadata": {},
65 | "outputs": [],
66 | "source": [
67 | "# add one torch_audiomentations waveform transform to the model\n",
68 | "from pyannote.audio.augmentation.registry import register_augmentation\n",
69 | "from torch_audiomentations import Gain\n",
70 | "gain = Gain(\n",
71 | " min_gain_in_db=-15.0,\n",
72 | " max_gain_in_db=5.0,\n",
73 | " p=0.5)\n",
74 | "register_augmentation(gain, identity, when='input')"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": null,
80 | "metadata": {},
81 | "outputs": [],
82 | "source": [
83 | "# pass the waveform through the \"augmented\" model\n",
84 | "Play(identity(waveforms).squeeze(), rate=sample_rate, normalize=False, autoplay=True)"
85 | ]
86 | }
87 | ],
88 | "metadata": {
89 | "kernelspec": {
90 | "display_name": "Python 3",
91 | "language": "python",
92 | "name": "python3"
93 | },
94 | "language_info": {
95 | "codemirror_mode": {
96 | "name": "ipython",
97 | "version": 3
98 | },
99 | "file_extension": ".py",
100 | "mimetype": "text/x-python",
101 | "name": "python",
102 | "nbconvert_exporter": "python",
103 | "pygments_lexer": "ipython3",
104 | "version": "3.7.9"
105 | }
106 | },
107 | "nbformat": 4,
108 | "nbformat_minor": 4
109 | }
110 |
--------------------------------------------------------------------------------
/diarizen/models/pruning/model_distill_prune.py:
--------------------------------------------------------------------------------
1 | # Licensed under the MIT license.
2 | # Copyright 2025 Brno University of Technology (author: Jiangyu Han, ihan@fit.vut.cz)
3 | # This work is inspired by: https://github.com/pyf98/DPHuBERT
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 | from diarizen.models.module.wav2vec2.model import wav2vec2_model
9 |
10 | class Model(nn.Module):
11 | def __init__(
12 | self,
13 | teacher_ckpt: str,
14 | student_ckpt: str,
15 | pruning_units: str = "conv,head,interm",
16 | distill_layers: str = "0,4,8,12",
17 | ):
18 | super().__init__()
19 |
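20 |         # transformer layer indices whose hidden states are matched between student and teacher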
20 | self.distill_layers = [int(l) for l in distill_layers.split(",")]
21 |
22 | self.teacher_model = self.build_teacher(teacher_ckpt)
23 | self.student_model, self.student_config = self.build_student(student_ckpt, pruning_units)
24 |
25 | def build_teacher(self, teacher_ckpt):
26 | teacher_ckpt = torch.load(teacher_ckpt, map_location="cpu")
27 | teacher_model = wav2vec2_model(**teacher_ckpt["config"])
28 | teacher_result = teacher_model.load_state_dict(teacher_ckpt["state_dict"], strict=False)
29 | print(f"Load pretrained ckpt to teacher: missing {teacher_result.missing_keys}, unexpected {teacher_result.unexpected_keys}")
30 |
31 | # freeze teacher
32 | for p in teacher_model.parameters():
33 | p.requires_grad = False
34 | print("Freeze parameters of the teacher model by setting requires_grad=False")
35 | teacher_model.eval()
36 | return teacher_model
37 |
38 | def build_student(self, student_ckpt, pruning_units):
39 | student_ckpt = torch.load(student_ckpt, map_location="cpu")
40 | pruning_units = pruning_units.split(",")
41 | print(f"Pruning units: {pruning_units}")
42 | student_config = student_ckpt['config']
43 | student_config.update(
44 | dict(
45 | extractor_prune_conv_channels = "conv" in pruning_units,
46 | encoder_prune_attention_heads = "head" in pruning_units,
47 | encoder_prune_attention_layer = "attlayer" in pruning_units,
48 | encoder_prune_feed_forward_intermediate = "interm" in pruning_units,
49 | encoder_prune_feed_forward_layer = "ffnlayer" in pruning_units,
50 | )
51 | )
52 | student_model = wav2vec2_model(**student_config)
53 | student_result = student_model.load_state_dict(student_ckpt["state_dict"], strict=False)
54 | print(f"Load pretrained ckpt to student: missing {student_result.missing_keys}, unexpected {student_result.unexpected_keys}")
55 | return student_model, student_config
56 |
57 |     def forward(self, waveforms: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
58 | self.teacher_model.eval()
59 | with torch.no_grad():
60 | teacher_hiddens, _ = self.teacher_model.extract_features(waveforms)
61 | teacher_hiddens = torch.stack(
62 | [teacher_hiddens[idx] for idx in self.distill_layers], dim=1
63 | ) # (batch, layer, time, feature)
64 |
65 | student_hiddens, _ = self.student_model.extract_features(waveforms)
66 | student_hiddens = torch.stack(
67 | [student_hiddens[idx] for idx in self.distill_layers], dim=1
68 | ) # (batch, layer, time, feature)
69 |
70 | return student_hiddens, teacher_hiddens
--------------------------------------------------------------------------------
/pyannote-audio/setup.cfg:
--------------------------------------------------------------------------------
1 | # This file is used to configure your project.
2 | # Read more about the various options under:
3 | # http://setuptools.readthedocs.io/en/latest/setuptools.html#configuring-setup-using-setup-cfg-files
4 |
5 | [metadata]
6 | name = pyannote-audio
7 | description = Neural speaker diarization
8 | author = Herve Bredin
9 | author-email = herve.bredin@irit.fr
10 | license = mit
11 | long-description = file: README.md
12 | long-description-content-type = text/markdown; charset=UTF-8; variant=GFM
13 | # Change if running only on Windows, Mac or Linux (comma-separated)
14 | platforms = Linux, Mac
15 | # Add here all kinds of additional classifiers as defined under
16 | # https://pypi.python.org/pypi?%3Aaction=list_classifiers
17 | classifiers =
18 | Development Status :: 4 - Beta
19 | Programming Language :: Python
20 |
21 | [options]
22 | zip_safe = False
23 | packages = find:
24 | include_package_data = True
25 | # DON'T CHANGE THE FOLLOWING LINE! IT WILL BE UPDATED BY PYSCAFFOLD!
26 | setup_requires = pyscaffold>=3.2a0,<3.3a0
27 | # Add here dependencies of your project (semicolon/line-separated), e.g.
28 | # install_requires = numpy; scipy
29 | # Require a specific Python version, e.g. Python 2.7 or >= 3.4
30 | python_requires = >=3.7
31 |
32 | [options.packages.find]
33 | where = .
34 | exclude =
35 | tests
36 |
37 | [options.extras_require]
38 | # Add here additional requirements for extra features, to install with:
39 | # `pip install fastaudio[PDF]` like:
40 | # PDF = ReportLab; RXP
41 | # Add here test requirements (semicolon/line-separated)
42 | testing =
43 | pytest>=6.0
44 | pytest-cov>=2.10
45 | jupyter
46 | papermill
47 | dev =
48 | pre_commit>=2.7
49 | recommonmark>=0.6
50 | black>=22.3.0
51 | cli =
52 | hydra-core >=1.1,<1.2
53 | typer >= 0.4.0,<0.5.0
54 |
55 | [options.entry_points]
56 |
57 | console_scripts =
58 | pyannote-audio-train=pyannote.audio.cli.train:train
59 | pyannote-audio-eval=pyannote.audio.cli.evaluate:evaluate
60 |
61 |
62 | [test]
63 | # py.test options when running `python setup.py test`
64 | # addopts = --verbose
65 | extras = True
66 |
67 | [tool:pytest]
68 | # Options for py.test:
69 | # Specify command line options as you would do when invoking py.test directly.
70 | # e.g. --cov-report html (or xml) for html/xml output or --junitxml junit.xml
71 | # in order to write a coverage file that can be read by Jenkins.
72 | addopts =
73 | --cov pyannote --cov-report term-missing
74 | --verbose
75 | norecursedirs =
76 | dist
77 | build
78 | .tox
79 | testpaths = tests
80 |
81 | [aliases]
82 | dists = bdist_wheel
83 |
84 | [bdist_wheel]
85 | # Use this option if your package is pure-python
86 | universal = 1
87 |
88 | [build_sphinx]
89 | source_dir = doc
90 | build_dir = build/sphinx
91 |
92 | [devpi:upload]
93 | # Options for the devpi: PyPI server and packaging tool
94 | # VCS export must be deactivated since we are using setuptools-scm
95 | no-vcs = 1
96 | formats = bdist_wheel
97 |
98 | [flake8]
99 | # Some sane defaults for the code style checker flake8
100 | exclude =
101 | .tox
102 | build
103 | dist
104 | .eggs
105 | docs/conf.py
106 |
107 | [pyscaffold]
108 | # PyScaffold's parameters when the project was created.
109 | # This will be used when updating. Do not change!
110 | version = 3.2.3
111 | package = pyannote-audio
112 | extensions =
113 | markdown
114 | no_skeleton
115 | pre_commit
116 | dsproject
117 |
--------------------------------------------------------------------------------
/pyannote-audio/tests/utils/test_powerset.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2023- CNRS
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 |
24 | import torch
25 |
26 | from pyannote.audio.utils.powerset import Powerset
27 |
28 |
29 | def test_roundtrip():
30 | for num_classes in range(2, 5):
31 | for max_set_size in range(1, num_classes + 1):
32 | powerset = Powerset(num_classes, max_set_size)
33 |
34 | # simulate a sequence where each frame is assigned to a different powerset class
35 | one_sequence = [
36 | [0] * powerset.num_powerset_classes
37 | for _ in range(powerset.num_powerset_classes)
38 | ]
39 | for i in range(powerset.num_powerset_classes):
40 | one_sequence[i][i] = 1.0
41 |
42 | # make a batch out of this sequence and the same sequence in reverse order
43 | batch_powerset = torch.tensor([one_sequence, one_sequence[::-1]])
44 |
45 | # convert from powerset to multi-label
46 | batch_multilabel = powerset.to_multilabel(batch_powerset)
47 |
48 | # convert batch back to powerset
49 | reconstruction = powerset.to_powerset(batch_multilabel)
50 |
51 | assert torch.equal(batch_powerset, reconstruction)
52 |
53 |
54 | def test_permutate_powerset():
55 | for num_classes in range(1, 6):
56 | for max_set_size in range(1, num_classes + 1):
57 | powerset = Powerset(num_classes, max_set_size)
58 |
59 |             # create a (num_powerset_classes, num_powerset_classes)-shaped tensor, where each frame is assigned to a different powerset class,
60 |             # and convert it to its multi-label equivalent
61 | t1 = torch.nn.functional.one_hot(
62 | torch.arange(powerset.num_powerset_classes),
63 | powerset.num_powerset_classes,
64 | )
65 | t1_ml = powerset.to_multilabel(t1)
66 |
67 |             # then permutate the powerset classes in powerset space AND the multi-label equivalent in its native space,
68 |             # and check that both give the same result
69 | # perm = torch.randperm(num_classes)
70 | perm = tuple(torch.randperm(num_classes).tolist())
71 | t1_ml_perm = t1_ml[:, perm]
72 | perm_ps = powerset.permutation_mapping[perm]
73 | t1_ps_perm = t1[..., perm_ps]
74 | t1_ps_perm_ml = powerset.to_multilabel(t1_ps_perm)
75 |
76 | assert t1_ml_perm.equal(t1_ps_perm_ml)
77 |
--------------------------------------------------------------------------------
/recipes/diar_ssl/data/AMI_AliMeeting_AISHELL4/dev/wav.scp:
--------------------------------------------------------------------------------
1 | ES2011a /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/ES2011a.wav
2 | ES2011b /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/ES2011b.wav
3 | ES2011c /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/ES2011c.wav
4 | ES2011d /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/ES2011d.wav
5 | IB4001 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/IB4001.wav
6 | IB4002 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/IB4002.wav
7 | IB4003 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/IB4003.wav
8 | IB4004 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/IB4004.wav
9 | IB4010 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/IB4010.wav
10 | IB4011 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/IB4011.wav
11 | IS1008a /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/IS1008a.wav
12 | IS1008b /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/IS1008b.wav
13 | IS1008c /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/IS1008c.wav
14 | IS1008d /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/IS1008d.wav
15 | TS3004a /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/TS3004a.wav
16 | TS3004b /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/TS3004b.wav
17 | TS3004c /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/TS3004c.wav
18 | TS3004d /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/TS3004d.wav
19 | R8001_M8004_MS801 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/R8001_M8004_MS801.wav
20 | R8003_M8001_MS801 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/R8003_M8001_MS801.wav
21 | R8007_M8010_MS803 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/R8007_M8010_MS803.wav
22 | R8007_M8011_MS806 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/R8007_M8011_MS806.wav
23 | R8008_M8013_MS807 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/R8008_M8013_MS807.wav
24 | R8009_M8018_MS809 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/R8009_M8018_MS809.wav
25 | R8009_M8019_MS810 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/R8009_M8019_MS810.wav
26 | R8009_M8020_MS810 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/R8009_M8020_MS810.wav
27 | 20200706_L_R001S01C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/20200706_L_R001S01C01.wav
28 | 20200708_L_R002S07C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/20200708_L_R002S07C01.wav
29 | 20200709_L_R002S08C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/20200709_L_R002S08C01.wav
30 | 20200616_M_R001S01C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/20200616_M_R001S01C01.wav
31 | 20200715_M_R002S06C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/20200715_M_R002S06C01.wav
32 | 20200620_M_R002S08C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/20200620_M_R002S08C01.wav
33 | 20200620_M_R002S07C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/20200620_M_R002S07C01.wav
34 | 20200710_M_R002S03C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/20200710_M_R002S03C01.wav
35 | 20200710_M_R002S07C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/20200710_M_R002S07C01.wav
36 | 20200712_M_R002S05C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/20200712_M_R002S05C01.wav
37 | 20200704_M_R002S06C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/20200704_M_R002S06C01.wav
38 | 20200622_M_R002S02C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/20200622_M_R002S02C01.wav
39 | 20200715_M_R002S03C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/20200715_M_R002S03C01.wav
40 | 20200623_S_R001S06C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/20200623_S_R001S06C01.wav
41 | 20200805_S_R001S01C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/20200805_S_R001S01C01.wav
42 | 20200701_S_R001S03C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/20200701_S_R001S03C01.wav
43 | 20200805_S_R001S03C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/20200805_S_R001S03C01.wav
44 | 20200702_S_R001S02C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/20200702_S_R001S02C01.wav
45 |
--------------------------------------------------------------------------------
/recipes/diar_ssl/data/NSF/dev_sc/all.uem:
--------------------------------------------------------------------------------
1 | S30500101 1 0.000 380.935
2 | S30500201 1 0.000 380.935
3 | S30500211 1 0.000 380.935
4 | S30502101 1 0.000 536.593
5 | S30502201 1 0.000 536.593
6 | S30502211 1 0.000 536.593
7 | S30503101 1 0.000 396.243
8 | S30503201 1 0.000 396.243
9 | S30503211 1 0.000 396.243
10 | S30504101 1 0.000 410.198
11 | S30504201 1 0.000 410.198
12 | S30504211 1 0.000 410.198
13 | S30505101 1 0.000 341.835
14 | S30505201 1 0.000 341.835
15 | S30505211 1 0.000 341.835
16 | S30508101 1 0.000 357.795
17 | S30508201 1 0.000 357.795
18 | S30508211 1 0.000 357.795
19 | S30509101 1 0.000 371.505
20 | S30509201 1 0.000 371.505
21 | S30509211 1 0.000 371.505
22 | S30519101 1 0.000 291.910
23 | S30519201 1 0.000 291.910
24 | S30519211 1 0.000 291.910
25 | S30520101 1 0.000 392.993
26 | S30520201 1 0.000 392.993
27 | S30520211 1 0.000 392.993
28 | S30522101 1 0.000 358.823
29 | S30522201 1 0.000 358.823
30 | S30522211 1 0.000 358.823
31 | S30552101 1 0.000 318.121
32 | S30552201 1 0.000 318.121
33 | S30552211 1 0.000 318.121
34 | S30552301 1 0.000 318.121
35 | S30560101 1 0.000 298.760
36 | S30560201 1 0.000 298.760
37 | S30560211 1 0.000 298.760
38 | S30560301 1 0.000 298.760
39 | S30563101 1 0.000 376.491
40 | S30563201 1 0.000 376.491
41 | S30563211 1 0.000 376.491
42 | S30563301 1 0.000 376.491
43 | S30569101 1 0.000 320.840
44 | S30569201 1 0.000 320.840
45 | S30569211 1 0.000 320.840
46 | S30569301 1 0.000 320.840
47 | S30570101 1 0.000 358.233
48 | S30570201 1 0.000 358.233
49 | S30570211 1 0.000 358.233
50 | S30570301 1 0.000 358.233
51 | S30601101 1 0.000 362.679
52 | S30601201 1 0.000 362.679
53 | S30601211 1 0.000 362.679
54 | S30601301 1 0.000 362.679
55 | S30603101 1 0.000 414.218
56 | S30603201 1 0.000 414.218
57 | S30603211 1 0.000 414.218
58 | S30603301 1 0.000 414.218
59 | S30605101 1 0.000 348.659
60 | S30605201 1 0.000 348.659
61 | S30605211 1 0.000 348.659
62 | S30605301 1 0.000 348.659
63 | S30609101 1 0.000 367.714
64 | S30609201 1 0.000 367.714
65 | S30609211 1 0.000 367.714
66 | S30610101 1 0.000 492.053
67 | S30610201 1 0.000 492.053
68 | S30610211 1 0.000 492.053
69 | S30633101 1 0.000 373.159
70 | S30633201 1 0.000 373.159
71 | S30633211 1 0.000 373.159
72 | S30633301 1 0.000 373.159
73 | S30636101 1 0.000 404.768
74 | S30636201 1 0.000 404.768
75 | S30636211 1 0.000 404.768
76 | S30636301 1 0.000 404.768
77 | S30639101 1 0.000 362.749
78 | S30639201 1 0.000 362.749
79 | S30639211 1 0.000 362.749
80 | S30639301 1 0.000 362.749
81 | S30716101 1 0.000 251.385
82 | S30716201 1 0.000 251.385
83 | S30716211 1 0.000 251.385
84 | S30716301 1 0.000 251.385
85 | S30718101 1 0.000 356.323
86 | S30718201 1 0.000 356.323
87 | S30718211 1 0.000 356.323
88 | S30718301 1 0.000 356.323
89 | S30719101 1 0.000 406.550
90 | S30719201 1 0.000 406.550
91 | S30719211 1 0.000 406.550
92 | S30719301 1 0.000 406.550
93 | S30721201 1 0.000 367.565
94 | S30721211 1 0.000 367.565
95 | S30721301 1 0.000 367.565
96 | S30724201 1 0.000 378.140
97 | S30724211 1 0.000 378.140
98 | S30724301 1 0.000 378.140
99 | S30725201 1 0.000 391.870
100 | S30725211 1 0.000 391.870
101 | S30725301 1 0.000 391.870
102 | S30752101 1 0.000 385.781
103 | S30752201 1 0.000 385.781
104 | S30752211 1 0.000 385.781
105 | S30752301 1 0.000 385.781
106 | S30754101 1 0.000 444.688
107 | S30754201 1 0.000 444.688
108 | S30754211 1 0.000 444.688
109 | S30754301 1 0.000 444.688
110 | S30773101 1 0.000 420.119
111 | S30773201 1 0.000 420.119
112 | S30773211 1 0.000 420.119
113 | S30773301 1 0.000 420.119
114 | S30774101 1 0.000 372.440
115 | S30774201 1 0.000 372.440
116 | S30774211 1 0.000 372.440
117 | S30774301 1 0.000 372.440
118 |
--------------------------------------------------------------------------------
/pyannote-audio/.github/ISSUE_TEMPLATE/bug_report.yml:
--------------------------------------------------------------------------------
1 | name: Bug report
2 | description: Report a bug in pyannote.audio
3 | body:
4 |
5 | - type: markdown
6 | attributes:
7 | value: |
8 | When reporting bugs, please follow the guidelines in this template. This helps identify the problem precisely and thus enables contributors to fix it faster.
9 | - Write a descriptive issue title above.
10 | - The golden rule is to **always open *one* issue for *one* bug**. If you notice several bugs and want to report them, make sure to create one new issue for each of them.
11 | - Search [open](https://github.com/pyannote/pyannote-audio/issues) and [closed](https://github.com/pyannote/pyannote-audio/issues?q=is%3Aissue+is%3Aclosed) issues to ensure it has not already been reported. If you don't find a relevant match or if you're unsure, don't hesitate to **open a new issue**. The bugsquad will handle it from there if it's a duplicate.
12 | - Please always check if your issue is reproducible in the latest version – it may already have been fixed!
13 | - If you use a custom build, please test if your issue is reproducible in official releases too.
14 |
15 | - type: textarea
16 | attributes:
17 | label: Tested versions
18 | description: |
19 | To properly fix a bug, we need to identify if the bug was recently introduced in the engine, or if it was always present.
20 | - Please specify the pyannote.audio version you found the issue in, including the **Git commit hash** if using a development build.
21 | - If you can, **please test earlier pyannote.audio versions** and, if applicable, newer versions (development branch). Mention whether the bug is reproducible or not in the versions you tested.
22 | - The aim is for us to identify whether a bug is a **regression**, i.e. an issue that didn't exist in a previous version, but was introduced later on, breaking existing functionality. For example, if a bug is reproducible in 3.2 but not in 3.0, we would like you to test intermediate 3.1 to find which version is the first one where the issue can be reproduced.
23 | placeholder: |
24 | - Reproducible in: 3.1, 3.2, and later
25 | - Not reproducible in: 3.0
26 | validations:
27 | required: true
28 |
29 | - type: input
30 | attributes:
31 | label: System information
32 | description: |
33 | - Specify the OS version, and when relevant hardware information.
34 | - For issues that are likely OS-specific and/or GPU-related, please specify the GPU model and architecture.
35 | - **Bug reports not including the required information may be closed at the maintainers' discretion.** If in doubt, always include all the requested information; it's better to include too much information than not enough information.
36 | placeholder: macOS 13.6 - pyannote.audio 3.1.1 - M1 Pro
37 | validations:
38 | required: true
39 |
40 | - type: textarea
41 | attributes:
42 | label: Issue description
43 | description: |
44 | Describe your issue briefly. What doesn't work, and how do you expect it to work instead?
45 | You can include audio, images or videos with drag and drop, and format code blocks or logs with ``` tags.
46 | validations:
47 | required: true
48 |
49 | - type: input
50 | attributes:
51 | label: Minimal reproduction example (MRE)
52 | description: |
53 | Having reproducible issues is a prerequisite for contributors to be able to solve them.
54 | Include a link to minimal reproduction example using [this Google Colab notebook](https://colab.research.google.com/github/pyannote/pyannote-audio/blob/develop/tutorials/MRE_template.ipynb) as a starting point.
55 | validations:
56 | required: true
57 |
--------------------------------------------------------------------------------
/pyannote-audio/FAQ.md:
--------------------------------------------------------------------------------
1 |
2 | # Frequently Asked Questions
3 | - [Can I apply pretrained pipelines on audio already loaded in memory?](#can-i-apply-pretrained-pipelines-on-audio-already-loaded-in-memory)
4 | - [Can I use gated models (and pipelines) offline?](#can-i-use-gated-models-and-pipelines-offline)
5 | - [Does pyannote support streaming speaker diarization?](#does-pyannote-support-streaming-speaker-diarization)
6 | - [How can I improve performance?](#how-can-i-improve-performance)
7 | - [How does one spell and pronounce pyannote.audio?](#how-does-one-spell-and-pronounce-pyannoteaudio)
8 |
9 |
10 | ## Can I apply pretrained pipelines on audio already loaded in memory?
11 |
12 | Yes: read [this tutorial](tutorials/applying_a_pipeline.ipynb) until the end.
13 |
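14 | For instance, a minimal sketch (the pipeline name and audio below are placeholders; gated pipelines may also need an authentication token):
15 | 
16 | ```python
17 | import torch
18 | 
19 | from pyannote.audio import Pipeline
20 | 
21 | pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization")
22 | 
23 | # 10 seconds of mono audio at 16 kHz, already loaded in memory
24 | waveform = torch.rand(1, 16000 * 10)  # (channel, time)
25 | diarization = pipeline({"waveform": waveform, "sample_rate": 16000})
26 | ```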
14 |
15 | ## Can I use gated models (and pipelines) offline?
16 |
17 | **Short answer**: yes, see [this tutorial](tutorials/applying_a_model.ipynb) for models and [that one](tutorials/applying_a_pipeline.ipynb) for pipelines.
18 |
19 | **Long answer**: gating models and pipelines allows [me](https://herve.niderb.fr) to know a bit more about `pyannote.audio` user base and eventually help me write grant proposals to make `pyannote.audio` even better. So, please fill gating forms as precisely as possible.
20 |
21 | For instance, before gating `pyannote/speaker-diarization`, I had no idea that so many people were relying on it in production. Hint: sponsors are more than welcome! Maintaining open source libraries is time consuming.
22 |
23 | That being said, this whole authentication process does not prevent you from using official `pyannote.audio` models offline (i.e. without going through the authentication process in every `docker run ...` or whatever you are using in production): see [this tutorial](tutorials/applying_a_model.ipynb) for models and [that one](tutorials/applying_a_pipeline.ipynb) for pipelines.
24 |
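25 | As a rough sketch, once a model checkpoint is already on disk (the path below is a placeholder):
26 | 
27 | ```python
28 | from pyannote.audio import Model
29 | 
30 | # loading from a local file does not require any authentication round-trip
31 | model = Model.from_pretrained("/path/to/pytorch_model.bin")
32 | ```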
25 |
26 | ## Does pyannote support streaming speaker diarization?
27 |
28 | **Short answer:** not out of the box, no.
29 |
30 | **Long answer:** [I](https://herve.niderb.fr) am looking for sponsors to add this feature. In the meantime, [`diart`](https://github.com/juanmc2005/StreamingSpeakerDiarization) is the closest you can get from a streaming `pyannote.audio`. You might also be interested in [this blog post](https://herve.niderb.fr/fastpages/2021/08/05/Streaming-voice-activity-detection-with-pyannote.html) about streaming voice activity detection based on `pyannote.audio`.
31 |
32 |
33 | ## How can I improve performance?
34 |
35 | **Long answer:**
36 |
37 | 1. Manually annotate dozens of conversations as precisely as possible.
38 | 2. Separate them into train (80%), development (10%) and test (10%) subsets.
39 | 3. Set up the data for use with [`pyannote.database`](https://github.com/pyannote/pyannote-database#speaker-diarization) (see the sketch below).
40 | 4. Follow [this recipe](https://github.com/pyannote/pyannote-audio/blob/develop/tutorials/adapting_pretrained_pipeline.ipynb).
41 | 5. Enjoy.
42 |
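43 | For step 3, loading your data then looks roughly like this (the protocol name below is a placeholder):
44 | 
45 | ```python
46 | from pyannote.database import FileFinder, get_protocol
47 | 
48 | protocol = get_protocol(
49 |     "MyDatabase.SpeakerDiarization.MyProtocol",
50 |     preprocessors={"audio": FileFinder()},
51 | )
52 | ```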
43 | **Also:** [I am available](https://herve.niderb.fr) for contracting to help you with that.
44 |
45 |
46 | ## How does one spell and pronounce pyannote.audio?
47 |
48 | 📝 Written in lower case: `pyannote.audio` (or `pyannote` if you are lazy). Not `PyAnnote` nor `PyAnnotate` (sic).
49 | 📢 Pronounced like the french verb `pianoter`. `pi` like in `pi`ano, not `py` like in `py`thon.
50 | 🎹 `pianoter` means to play the piano (hence the logo 🤯).
51 |
52 |
53 |
54 | Generated by [FAQtory](https://github.com/willmcgugan/faqtory)
55 |
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/utils/reproducibility.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2023- CNRS
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | # Context: https://github.com/pyannote/pyannote-audio/issues/1370
24 |
25 | import warnings
26 |
27 | import torch
28 |
29 |
30 | class ReproducibilityError(Exception):
31 | ...
32 |
33 |
34 | class ReproducibilityWarning(UserWarning):
35 | ...
36 |
37 |
38 | def raise_reproducibility(device: torch.device):
39 | if (device.type == "cuda") and (
40 | torch.backends.cuda.matmul.allow_tf32 or torch.backends.cudnn.allow_tf32
41 | ):
42 | raise ReproducibilityError(
43 | "Please disable TensorFloat-32 (TF32) by calling\n"
44 | " >>> import torch\n"
45 | " >>> torch.backends.cuda.matmul.allow_tf32 = False\n"
46 | " >>> torch.backends.cudnn.allow_tf32 = False\n"
47 | "or you might face reproducibility issues and obtain lower accuracy.\n"
48 | "See https://github.com/pyannote/pyannote-audio/issues/1370 for more details."
49 | )
50 |
51 |
52 | def warn_reproducibility(device: torch.device):
53 | if (device.type == "cuda") and (
54 | torch.backends.cuda.matmul.allow_tf32 or torch.backends.cudnn.allow_tf32
55 | ):
56 | warnings.warn(
57 | ReproducibilityWarning(
58 | "Please disable TensorFloat-32 (TF32) by calling\n"
59 | " >>> import torch\n"
60 | " >>> torch.backends.cuda.matmul.allow_tf32 = False\n"
61 | " >>> torch.backends.cudnn.allow_tf32 = False\n"
62 | "or you might face reproducibility issues and obtain lower accuracy.\n"
63 | "See https://github.com/pyannote/pyannote-audio/issues/1370 for more details."
64 | )
65 | )
66 |
67 |
68 | def fix_reproducibility(device: torch.device):
69 | if (device.type == "cuda") and (
70 | torch.backends.cuda.matmul.allow_tf32 or torch.backends.cudnn.allow_tf32
71 | ):
72 | torch.backends.cuda.matmul.allow_tf32 = False
73 | torch.backends.cudnn.allow_tf32 = False
74 | warnings.warn(
75 | ReproducibilityWarning(
76 | "TensorFloat-32 (TF32) has been disabled as it might lead to reproducibility issues and lower accuracy.\n"
77 | "It can be re-enabled by calling\n"
78 | " >>> import torch\n"
79 | " >>> torch.backends.cuda.matmul.allow_tf32 = True\n"
80 | " >>> torch.backends.cudnn.allow_tf32 = True\n"
81 | "See https://github.com/pyannote/pyannote-audio/issues/1370 for more details.\n"
82 | )
83 | )
84 |
--------------------------------------------------------------------------------
/recipes/diar_ssl/run_stage.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Licensed under the MIT license.
4 | # Copyright 2024 Brno University of Technology (author: Jiangyu Han, ihan@fit.vut.cz)
5 |
6 | set -eu
7 | ulimit -n 2048
8 |
9 | # general setup
10 | stage=1
11 | recipe_root=/YOUR_PATH/DiariZen/recipes/diar_ssl
12 | exp_root=$recipe_root/exp
13 | conf_dir=$recipe_root/conf
14 |
15 | # training setup
16 | use_dual_opt=true # true for wavlm_updated_conformer.toml; false for the others
17 | train_conf=$conf_dir/wavlm_updated_conformer.toml
18 | # train_conf=$conf_dir/wavlm_frozen_conformer.toml
19 | # train_conf=$conf_dir/fbank_conformer.toml
20 | # train_conf=$conf_dir/pyannote_baseline.toml
21 |
22 | conf_name=$(basename $train_conf .toml)
23 |
24 | # inference setup
25 | dtype=test
26 | data_dir=$recipe_root/data/AMI_AliMeeting_AISHELL4
27 | seg_duration=8
28 |
29 | # clustering setup
30 | clustering_method=AgglomerativeClustering
31 | ahc_threshold=0.70
32 | min_cluster_size=30
33 | infer_affix=_constrained_AHC_thres_${ahc_threshold}_mcs_${min_cluster_size}
34 |
35 | avg_ckpt_num=5
36 | val_metric=Loss # Loss or DER
37 | val_mode=best # [prev, best, center]
38 |
39 | # scoring setup
40 | collar=0
41 | REF_DIR=$data_dir
42 | dscore_dir=/YOUR_PATH/DiariZen/dscore
43 |
44 | # =======================================
45 | # =======================================
46 | if [ $stage -le 1 ]; then
47 |     if ! $use_dual_opt; then
48 | echo "stage1: use single-opt for model training..."
49 | conda activate diarizen && CUDA_VISIBLE_DEVICES="0,1" accelerate launch \
50 | --num_processes 2 --main_process_port 1134 \
51 |             run_single_opt.py -C $train_conf -M train
52 | else
53 | echo "stage1: use dual-opt for model training..."
54 | conda activate diarizen && CUDA_VISIBLE_DEVICES="0,1,2,3" accelerate launch \
55 | --num_processes 4 --main_process_port 1134 \
56 | run_dual_opt.py -C $train_conf -M train
57 | fi
58 | fi
59 |
60 | diarization_dir=$exp_root/$conf_name # can be replaced by our pre-trained models, e.g. diarization_dir=/YOUR_PATH/checkpoints/wavlm_updated_conformer
61 | config_dir=$(ls $diarization_dir/*.toml | sort -r | head -n 1)
62 | embedding_model=/YOUR_PATH/pretrained/pyannote3/wespeaker-voxceleb-resnet34-LM/pytorch_model.bin # it's necessary to have "pyannote" in your directory path
63 |
64 | if [ $stage -le 2 ]; then
65 | echo "stage2: model inference..."
66 | export CUDA_VISIBLE_DEVICES=0
67 |
68 |     train_log=$(du -h $diarization_dir/*.log | sort -rh | head -n 1 | awk '{print $NF}')
69 |     grep 'Loss/DER' $train_log | awk -F ']:' '{print $NF}' > $diarization_dir/val_metric_summary.lst
70 |
71 | for dset in AMI AliMeeting AISHELL4; do
72 | conda activate diarizen && python infer_avg.py -C $config_dir \
73 | -i ${data_dir}/${dtype}/${dset}/wav.scp \
74 | -o ${diarization_dir}/infer$infer_affix/metric_${val_metric}_${val_mode}/avg_ckpt${avg_ckpt_num}/${dtype}/${dset} \
75 | --embedding_model $embedding_model \
76 | --avg_ckpt_num $avg_ckpt_num \
77 | --val_metric $val_metric \
78 | --val_mode $val_mode \
79 | --val_metric_summary $diarization_dir/val_metric_summary.lst \
80 | --seg_duration $seg_duration \
81 | --clustering_method $clustering_method \
82 | --ahc_threshold $ahc_threshold \
83 | --min_cluster_size $min_cluster_size
84 |
85 | echo "stage3: scoring..."
86 | SYS_DIR=${diarization_dir}/infer$infer_affix/metric_${val_metric}_${val_mode}/avg_ckpt${avg_ckpt_num}
87 | OUT_DIR=${SYS_DIR}/${dtype}/${dset}
88 | conda activate diarizen && python ${dscore_dir}/score.py \
89 | -r ${REF_DIR}/${dtype}/${dset}/rttm \
90 | -s $OUT_DIR/*.rttm --collar ${collar} \
91 | > $OUT_DIR/result_collar${collar}
92 | done
93 | fi
94 |
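Each block above is gated by `[ $stage -le N ]`, so every stage numbered at or above `$stage` runs. To skip training and only re-run inference and scoring against an existing experiment directory (or a pre-trained checkpoint directory, as noted in the inline comments), it should suffice to adjust the setup variables at the top, e.g. (paths are placeholders):

    stage=2
    recipe_root=/YOUR_PATH/DiariZen/recipes/diar_ssl
    dscore_dir=/YOUR_PATH/DiariZen/dscore

Note that the scoring step ("stage3") is nested inside the stage-2 loop, so inference and scoring always run together, once per dataset.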
--------------------------------------------------------------------------------
/pyannote-audio/pyannote/audio/cli/evaluate.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2022- CNRS
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 |
24 | from typing import Optional
25 |
26 | import hydra
27 | from omegaconf import DictConfig
28 | from pyannote.database import FileFinder, ProtocolFile, registry
29 | from rich.progress import Progress
30 |
31 | from pyannote.audio import Inference, Model
32 | from pyannote.audio.pipelines.utils import get_devices
33 | from pyannote.audio.utils.metric import DiscreteDiarizationErrorRate
34 | from pyannote.audio.utils.signal import binarize
35 |
36 |
37 | @hydra.main(config_path="evaluate_config", config_name="config")
38 | def evaluate(cfg: DictConfig) -> Optional[float]:
39 |
40 | # load pretrained model
41 | (device,) = get_devices(needs=1)
42 | model = Model.from_pretrained(cfg.model, device=device)
43 |
44 |     # load databases into the registry, if specified
45 | if "registry" in cfg:
46 | for database_yml in cfg.registry.split(","):
47 | registry.load_database(database_yml)
48 |
49 | # load evaluation files
50 | protocol = registry.get_protocol(
51 | cfg.protocol, preprocessors={"audio": FileFinder()}
52 | )
53 |
54 | files = list(getattr(protocol, cfg.subset)())
55 |
56 | # load evaluation metric
57 | metric = DiscreteDiarizationErrorRate()
58 |
59 | with Progress() as progress:
60 |
61 | main_task = progress.add_task(protocol.name, total=len(files))
62 | file_task = progress.add_task("Processing", total=1.0)
63 |
64 |         def progress_hook(completed: Optional[int] = None, total: Optional[int] = None):
65 |             if completed is not None and total is not None:
66 |                 progress.update(file_task, completed=completed / total)
67 | inference = Inference(model, device=device)
68 | warm_up = cfg.warm_up / inference.duration
69 |
70 | def hypothesis(file: ProtocolFile):
71 | return Inference.trim(
72 | binarize(inference(file, hook=progress_hook)),
73 | warm_up=(warm_up, warm_up),
74 | )
75 |
76 | for file in files:
77 | progress.update(file_task, description=file["uri"])
78 | reference = file["annotation"]
79 | uem = file["annotated"]
80 | _ = metric(reference, hypothesis(file), uem=uem)
81 | progress.advance(main_task)
82 |
83 | report = metric.report(display=False)
84 |
85 | with open("report.txt", "w") as f:
86 |
87 | f.write(f"# Model: {cfg.model}\n")
88 | f.write(f"# Protocol: {protocol.name}\n")
89 | f.write(f"# Subset: {cfg.subset}\n")
90 | f.write("\n")
91 | report = report.to_string(
92 | index=True,
93 | sparsify=False,
94 | justify="right",
95 |             float_format=lambda x: f"{x:.2f}",
96 | )
97 | f.write(f"{report}")
98 |
99 |
100 | if __name__ == "__main__":
101 | evaluate()
102 |
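Because `evaluate` is a Hydra entry point reading `evaluate_config/config.yaml`, its settings (`model`, `protocol`, `subset`, `warm_up`, and the optional `registry`) come from that config or from command-line overrides. A hypothetical invocation (checkpoint path and protocol name are placeholders, not taken from this repository's configs):

    python -m pyannote.audio.cli.evaluate \
        model=/path/to/model.ckpt \
        protocol=AMI.SpeakerDiarization.only_words \
        subset=test

The per-file DER breakdown is then written to `report.txt` inside Hydra's run directory.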
--------------------------------------------------------------------------------
/diarizen/ckpt_utils.py:
--------------------------------------------------------------------------------
1 | # Licensed under the MIT license.
2 | # Copyright 2022 Brno University of Technology (author: Federico Landini, landini@fit.vut.cz)
3 | # Copyright 2024 Brno University of Technology (author: Jiangyu Han, ihan@fit.vut.cz)
4 |
5 | import os
6 | from pathlib import Path
7 |
8 | import torch
9 | import torch.nn as nn
10 |
11 | import copy
12 |
13 | from typing import List, Dict
14 |
15 |
16 | def average_checkpoints(
17 |     model: nn.Module,
18 |     checkpoint_list: List[Dict],
19 | ) -> nn.Module:
20 | states_dict_list = []
21 | for ckpt_data in checkpoint_list:
22 | ckpt_path = ckpt_data['bin_path']
23 | copy_model = copy.deepcopy(model)
24 | checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
25 | copy_model.load_state_dict(checkpoint)
26 | states_dict_list.append(copy_model.state_dict())
27 | avg_state_dict = average_states(states_dict_list, torch.device('cpu'))
28 | avg_model = copy.deepcopy(model)
29 | avg_model.load_state_dict(avg_state_dict)
30 | return avg_model
31 |
32 | def average_states(
33 |     states_list: List[Dict[str, torch.Tensor]],
34 |     device: torch.device,
35 | ) -> Dict[str, torch.Tensor]:
36 | qty = len(states_list)
37 |     avg_state = states_list[0]  # note: the first state dict is modified in place during accumulation
38 | for i in range(1, qty):
39 | for key in avg_state:
40 | avg_state[key] += states_list[i][key].to(device)
41 | for key in avg_state:
42 | avg_state[key] = avg_state[key] / qty
43 | return avg_state
44 |
45 | def load_metric_summary(metric_file, ckpt_path):
46 | with open(metric_file, "r") as f:
47 | lines = f.readlines()
48 | out_lst = []
49 | for line in lines:
50 | assert "Validation Loss/DER" in line
51 | epoch = line.split()[4].split(':')[0]
52 | Loss, DER = line.split()[-3], line.split()[-1]
53 | bin_path = f"epoch_{str(epoch).zfill(4)}/pytorch_model.bin"
54 | out_lst.append({
55 | 'epoch': int(epoch),
56 | 'bin_path': ckpt_path / bin_path,
57 | 'Loss': float(Loss),
58 | 'DER': float(DER)
59 | })
60 | return out_lst
61 |
62 | def average_ckpt(ckpt_dir, model, wavlm_only=False, val_metric='Loss', avg_ckpt_num=5, val_mode="best"):
63 | if os.path.isfile(ckpt_dir):
64 | ckpt_loaded = torch.load(ckpt_dir, map_location=torch.device('cpu'))
65 | if not wavlm_only:
66 | print(f"No model averaging | Fine-tune model from: {ckpt_dir}")
67 | model.load_state_dict(ckpt_loaded, strict=True)
68 | else:
69 | print(f"Only initialize wavlm model from: {ckpt_dir}")
70 | wavlm_prefix = "wavlm_model."
71 | wavlm_state_dict = {
72 | k[len(wavlm_prefix):]: v
73 | for k, v in ckpt_loaded.items()
74 | if k.startswith(wavlm_prefix)
75 | }
76 | model.wavlm_model.load_state_dict(wavlm_state_dict, strict=False)
77 | return model
78 |
79 | if 'checkpoints/epoch_' in ckpt_dir:
80 | print(f"No model averaging | Fine-tune model from certain epoch: {ckpt_dir.split('/')[-1]}")
81 | ckpt_loaded = torch.load(os.path.join(ckpt_dir, 'pytorch_model.bin'), map_location=torch.device('cpu'))
82 | model.load_state_dict(ckpt_loaded)
83 | return model
84 |
85 | assert val_metric == "Loss" and val_mode == "best"
86 |     print(f'averaging the {avg_ckpt_num} checkpoints up to (and including) the best validation epoch...')
87 |
88 | ckpt_dir = Path(ckpt_dir).expanduser().absolute()
89 | ckpt_path = ckpt_dir / 'checkpoints'
90 | val_metric_path = ckpt_dir / 'val_metric_summary.lst'
91 |
92 | val_metric_lst = load_metric_summary(val_metric_path, ckpt_path)
93 | val_metric_lst_sorted = sorted(val_metric_lst, key=lambda i: i[val_metric])
94 | best_val_metric_idx = val_metric_lst.index(val_metric_lst_sorted[0])
95 |     val_metric_lst_out = val_metric_lst[
96 |         max(0, best_val_metric_idx - avg_ckpt_num + 1) :
97 |         best_val_metric_idx + 1
98 |     ]  # max(0, ...) guards against a negative slice start when the best epoch comes early
99 |
100 | return average_checkpoints(model, val_metric_lst_out)
101 |
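For reference, `load_metric_summary` assumes every line of `val_metric_summary.lst` contains `Validation Loss/DER`, carries the epoch in the fifth whitespace-separated token (up to a colon), and ends with the Loss and DER values as the third-to-last and last tokens. A hypothetical line consistent with that parser (the actual DiariZen log wording may differ):

    Validation Loss/DER on epoch 0012: Loss 0.512 DER 17.35

parses to `{'epoch': 12, 'bin_path': ckpt_path / 'epoch_0012/pytorch_model.bin', 'Loss': 0.512, 'DER': 17.35}`. A minimal sketch of the averaging entry point (illustrative; the model class and experiment path are placeholders):

    from diarizen.ckpt_utils import average_ckpt

    model = MyDiarizationModel()  # placeholder: any nn.Module matching the saved state dicts
    # averages the 5 consecutive checkpoints ending at the best validation-Loss epoch,
    # using <exp_dir>/checkpoints/ and <exp_dir>/val_metric_summary.lst
    avg_model = average_ckpt(
        "/YOUR_PATH/exp/wavlm_updated_conformer", model,
        val_metric="Loss", avg_ckpt_num=5, val_mode="best",
    )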
--------------------------------------------------------------------------------