├── diarizen ├── __init__.py ├── models │ ├── __init__.py │ ├── module │ │ └── wav2vec2 │ │ │ ├── __init__.py │ │ │ ├── utils │ │ │ └── __init__.py │ │ │ ├── .DS_Store │ │ │ └── pruning_utils.py │ └── pruning │ │ └── model_distill_prune.py ├── pipelines │ ├── __init__.py │ └── utils.py ├── optimization.py ├── trainer_utils.py ├── noam_updater.py ├── logger.py └── ckpt_utils.py ├── pyannote-audio ├── version.txt ├── .github │ ├── FUNDING.yml │ ├── ISSUE_TEMPLATE │ │ ├── config.yml │ │ └── bug_report.yml │ ├── workflows │ │ ├── pypi.yml │ │ ├── test.yml │ │ ├── test_cli.yml │ │ └── doc.yml │ └── stale.yml ├── codecov.yml ├── doc │ ├── requirements.txt │ ├── source │ │ └── index.rst │ └── gen_docs.py ├── tutorials │ ├── assets │ │ ├── sample.wav │ │ ├── download-model.png │ │ ├── pyannote.diff.PNG │ │ ├── download-pipeline.png │ │ ├── pyannote.review.PNG │ │ ├── prodigy-pyannote.audio.png │ │ └── sample.rttm │ └── speaker_verification.ipynb ├── pyannote │ ├── audio │ │ ├── cli │ │ │ ├── train_config │ │ │ │ ├── model │ │ │ │ │ ├── XVectorMFCC.yaml │ │ │ │ │ ├── Pretrained.yaml │ │ │ │ │ ├── XVectorSincNet.yaml │ │ │ │ │ ├── DebugEmbedding.yaml │ │ │ │ │ ├── DebugSegmentation.yaml │ │ │ │ │ ├── PyanNet.yaml │ │ │ │ │ └── SSeRiouSS.yaml │ │ │ │ ├── trainer │ │ │ │ │ ├── fast_dev_run.yaml │ │ │ │ │ └── default.yaml │ │ │ │ ├── optimizer │ │ │ │ │ ├── Adan.yaml │ │ │ │ │ ├── Adam.yaml │ │ │ │ │ └── AdamW.yaml │ │ │ │ ├── preprocessor │ │ │ │ │ └── LowerTemporalResolution.yaml │ │ │ │ ├── scheduler │ │ │ │ │ ├── CyclicLR.yaml │ │ │ │ │ ├── CosineAnnealingWarmRestarts.yaml │ │ │ │ │ └── ReduceLROnPlateau.yaml │ │ │ │ ├── config.yaml │ │ │ │ ├── task │ │ │ │ │ ├── SpeakerDiarization.yaml │ │ │ │ │ ├── MultiLabelSegmentation.yaml │ │ │ │ │ ├── VoiceActivityDetection.yaml │ │ │ │ │ ├── SpeakerEmbedding.yaml │ │ │ │ │ └── OverlappedSpeechDetection.yaml │ │ │ │ ├── hydra │ │ │ │ │ └── default.yaml │ │ │ │ └── __init__.py │ │ │ ├── evaluate_config │ │ │ │ ├── config.yaml │ │ │ │ ├── hydra │ │ │ │ │ └── default.yaml │ │ │ │ └── __init__.py │ │ │ ├── config │ │ │ │ └── hydra │ │ │ │ │ └── default.yaml │ │ │ ├── __init__.py │ │ │ ├── pretrained.py │ │ │ ├── lr_schedulers │ │ │ │ ├── __init__.py │ │ │ │ ├── CosineAnnealingWarmRestarts.py │ │ │ │ ├── CyclicLR.py │ │ │ │ └── ReduceLROnPlateau.py │ │ │ └── evaluate.py │ │ ├── sample │ │ │ ├── sample.wav │ │ │ ├── sample.rttm │ │ │ └── __init__.py │ │ ├── utils │ │ │ ├── params.py │ │ │ ├── __init__.py │ │ │ ├── version.py │ │ │ ├── random.py │ │ │ ├── multi_task.py │ │ │ └── reproducibility.py │ │ ├── core │ │ │ └── __init__.py │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── segmentation │ │ │ │ └── __init__.py │ │ │ └── embedding │ │ │ │ ├── wespeaker │ │ │ │ ├── LICENSE.WeSpeaker │ │ │ │ └── convert.py │ │ │ │ └── __init__.py │ │ ├── tasks │ │ │ ├── embedding │ │ │ │ └── __init__.py │ │ │ ├── segmentation │ │ │ │ └── __init__.py │ │ │ └── __init__.py │ │ ├── torchmetrics │ │ │ ├── functional │ │ │ │ ├── __init__.py │ │ │ │ └── audio │ │ │ │ │ └── __init__.py │ │ │ ├── classification │ │ │ │ ├── __init__.py │ │ │ │ └── equal_error_rate.py │ │ │ ├── __init__.py │ │ │ └── audio │ │ │ │ └── __init__.py │ │ ├── augmentation │ │ │ └── __init__.py │ │ ├── __init__.py │ │ └── pipelines │ │ │ ├── __init__.py │ │ │ └── utils │ │ │ └── __init__.py │ └── __init__.py ├── questions │ ├── README.md │ ├── from_memory.question.md │ ├── pyannote.question.md │ ├── streaming.question.md │ ├── bad_performance.question.md │ └── offline.question.md ├── .gitattributes ├── .gitmodules 
├── MANIFEST.in ├── environment.yaml ├── tests │ ├── test_import_lib.py │ ├── test_run_notebooks.py │ ├── utils │ │ ├── preview.py │ │ ├── probe_util_test.py │ │ ├── test_permutation.py │ │ └── test_powerset.py │ ├── tasks │ │ ├── test_specifications.py │ │ └── test_reproducibility.py │ ├── test_clustering.py │ ├── test_sample.py │ ├── conftest.py │ ├── io_test.py │ └── inference_test.py ├── faq.yml ├── .faq │ ├── FAQ.md │ └── suggest.md ├── requirements.txt ├── LICENSE ├── .pre-commit-config.yaml ├── .gitignore ├── setup.py ├── notebook │ ├── sharing.ipynb │ ├── freeze.ipynb │ └── augmentation.ipynb ├── setup.cfg └── FAQ.md ├── recipes ├── diar_ssl_pruning │ ├── data │ ├── convert_wavlm_from_hf.py │ ├── conf │ │ ├── dual_opt_common.toml │ │ ├── dual_opt_common_large.toml │ │ ├── s80_base.toml │ │ └── s80_large.toml │ └── README.md └── diar_ssl │ ├── data │ ├── AMI_AliMeeting_AISHELL4 │ │ ├── test │ │ │ ├── AMI │ │ │ │ ├── all.uem │ │ │ │ └── wav.scp │ │ │ ├── AISHELL4 │ │ │ │ ├── all.uem │ │ │ │ └── wav.scp │ │ │ └── AliMeeting │ │ │ │ ├── all.uem │ │ │ │ └── wav.scp │ │ └── dev │ │ │ ├── all.uem │ │ │ └── wav.scp │ └── NSF │ │ └── dev_sc │ │ └── all.uem │ ├── README.md │ ├── conf │ ├── pyannote_baseline.toml │ ├── wavlm_frozen_conformer.toml │ ├── fbank_conformer.toml │ └── wavlm_updated_conformer.toml │ └── run_stage.sh ├── .gitmodules ├── example └── EN2002a_30s.wav ├── requirements.txt ├── LICENSE ├── MODEL_LICENSE ├── .gitignore └── pyproject.toml /diarizen/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /diarizen/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /diarizen/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pyannote-audio/version.txt: -------------------------------------------------------------------------------- 1 | 3.1.1 2 | -------------------------------------------------------------------------------- /diarizen/models/module/wav2vec2/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /recipes/diar_ssl_pruning/data: -------------------------------------------------------------------------------- 1 | ../diar_ssl/data/ -------------------------------------------------------------------------------- /diarizen/models/module/wav2vec2/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "dscore"] 2 | path = dscore 3 | url = https://github.com/nryant/dscore.git 4 | -------------------------------------------------------------------------------- /example/EN2002a_30s.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUTSpeechFIT/DiariZen/HEAD/example/EN2002a_30s.wav -------------------------------------------------------------------------------- /pyannote-audio/.github/FUNDING.yml: 
-------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: [hbredin] 4 | -------------------------------------------------------------------------------- /pyannote-audio/codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | patch: 4 | default: 5 | enabled: false 6 | -------------------------------------------------------------------------------- /pyannote-audio/doc/requirements.txt: -------------------------------------------------------------------------------- 1 | ipython==8.10.0 2 | recommonmark 3 | Sphinx==3.0.4 4 | sphinx_rtd_theme==0.4.3 5 | -------------------------------------------------------------------------------- /diarizen/models/module/wav2vec2/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUTSpeechFIT/DiariZen/HEAD/diarizen/models/module/wav2vec2/.DS_Store -------------------------------------------------------------------------------- /pyannote-audio/tutorials/assets/sample.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUTSpeechFIT/DiariZen/HEAD/pyannote-audio/tutorials/assets/sample.wav -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/cli/train_config/model/XVectorMFCC.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: pyannote.audio.models.embedding.XVectorMFCC 3 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/sample/sample.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUTSpeechFIT/DiariZen/HEAD/pyannote-audio/pyannote/audio/sample/sample.wav -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/cli/train_config/model/Pretrained.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: pyannote.audio.cli.pretrained 3 | checkpoint: ??? 
4 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/cli/train_config/model/XVectorSincNet.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: pyannote.audio.models.embedding.XVectorSincNet 3 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/cli/train_config/trainer/fast_dev_run.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: pytorch_lightning.Trainer 3 | fast_dev_run: True 4 | -------------------------------------------------------------------------------- /pyannote-audio/tutorials/assets/download-model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUTSpeechFIT/DiariZen/HEAD/pyannote-audio/tutorials/assets/download-model.png -------------------------------------------------------------------------------- /pyannote-audio/tutorials/assets/pyannote.diff.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUTSpeechFIT/DiariZen/HEAD/pyannote-audio/tutorials/assets/pyannote.diff.PNG -------------------------------------------------------------------------------- /pyannote-audio/tutorials/assets/download-pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUTSpeechFIT/DiariZen/HEAD/pyannote-audio/tutorials/assets/download-pipeline.png -------------------------------------------------------------------------------- /pyannote-audio/tutorials/assets/pyannote.review.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUTSpeechFIT/DiariZen/HEAD/pyannote-audio/tutorials/assets/pyannote.review.PNG -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/cli/train_config/model/DebugEmbedding.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: pyannote.audio.models.embedding.debug.SimpleEmbeddingModel 3 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/cli/evaluate_config/config.yaml: -------------------------------------------------------------------------------- 1 | model: ??? 2 | protocol: ??? 
3 | warm_up: 0.0 4 | subset: test 5 | 6 | defaults: 7 | - hydra: default 8 | -------------------------------------------------------------------------------- /pyannote-audio/tutorials/assets/prodigy-pyannote.audio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUTSpeechFIT/DiariZen/HEAD/pyannote-audio/tutorials/assets/prodigy-pyannote.audio.png -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/cli/train_config/model/DebugSegmentation.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: pyannote.audio.models.segmentation.debug.SimpleSegmentationModel 3 | -------------------------------------------------------------------------------- /pyannote-audio/questions/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Questions 3 | 4 | Your questions should go in this directory. 5 | 6 | Question files should be named with the extension ".question.md". 7 | -------------------------------------------------------------------------------- /pyannote-audio/.gitattributes: -------------------------------------------------------------------------------- 1 | pyannote/audio/_version.py export-subst 2 | notebooks/* linguist-documentation 3 | tutorials/* linguist-documentation 4 | versioneer.py linguist-vendored 5 | -------------------------------------------------------------------------------- /pyannote-audio/.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "tutorials/AMI-diarization-setup"] 2 | path = tutorials/AMI-diarization-setup 3 | url = https://github.com/pyannote/AMI-diarization-setup.git 4 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/cli/train_config/optimizer/Adan.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: adan_pytorch.Adan 3 | lr: 1e-3 4 | betas: [0.1, 0.1, 0.001] 5 | weight_decay: 0.0 6 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/cli/train_config/preprocessor/LowerTemporalResolution.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: pyannote.audio.utils.preprocessors.LowerTemporalResolution 3 | resolution: 0.1 4 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/cli/train_config/optimizer/Adam.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: torch.optim.Adam 3 | lr: 1e-3 4 | betas: [0.9, 0.999] 5 | eps: 1e-08 6 | weight_decay: 0 7 | amsgrad: False 8 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/cli/train_config/optimizer/AdamW.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: torch.optim.AdamW 3 | lr: 1e-3 4 | betas: [0.9, 0.999] 5 | eps: 1e-08 6 | weight_decay: 0.01 7 | amsgrad: False 8 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/cli/train_config/scheduler/CyclicLR.yaml: -------------------------------------------------------------------------------- 1 | 
# @package _group_ 2 | _target_: pyannote.audio.cli.lr_schedulers.CyclicLR 3 | min_lr: 1e-8 4 | max_lr: 1e-3 5 | mode: triangular2 6 | patience: 50 7 | -------------------------------------------------------------------------------- /pyannote-audio/MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include pyannote *.py 2 | recursive-include pyannote *.yaml 3 | recursive-include pyannote *.wav 4 | recursive-include pyannote *.rttm 5 | global-exclude *.pyc 6 | global-exclude __pycache__ 7 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/cli/train_config/scheduler/CosineAnnealingWarmRestarts.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: pyannote.audio.cli.lr_schedulers.CosineAnnealingWarmRestarts 3 | min_lr: 1e-8 4 | max_lr: 1e-3 5 | patience: 1 6 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/cli/train_config/scheduler/ReduceLROnPlateau.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: pyannote.audio.cli.lr_schedulers.ReduceLROnPlateau 3 | min_lr: 1e-8 4 | max_lr: 1e-3 5 | factor: 0.5 6 | patience: 50 7 | -------------------------------------------------------------------------------- /pyannote-audio/environment.yaml: -------------------------------------------------------------------------------- 1 | name: pyannote-audio 2 | channels: 3 | - defaults 4 | - conda-forge 5 | dependencies: 6 | - python==3.8.5 7 | - libsndfile==1.0.28 8 | - pip>=20.2 9 | - pip: 10 | - -r requirements.txt 11 | -------------------------------------------------------------------------------- /pyannote-audio/tests/test_import_lib.py: -------------------------------------------------------------------------------- 1 | from pyannote.audio.core.model import Model 2 | 3 | 4 | def test_import_lib(): 5 | """This is a dummy test, just to check 6 | if the lib can be successfully imported. 7 | """ 8 | assert Model is not None 9 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/cli/train_config/config.yaml: -------------------------------------------------------------------------------- 1 | protocol: ??? 2 | 3 | defaults: 4 | - task: SpeakerDiarization 5 | - model: PyanNet 6 | - optimizer: Adam 7 | - scheduler: CosineAnnealingWarmRestarts 8 | - trainer: default 9 | - hydra: default 10 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/cli/train_config/task/SpeakerDiarization.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: pyannote.audio.tasks.SpeakerDiarization 3 | duration: 5.0 4 | max_speakers_per_chunk: 3 5 | max_speakers_per_frame: 2 6 | batch_size: 32 7 | num_workers: 10 8 | pin_memory: False 9 | -------------------------------------------------------------------------------- /pyannote-audio/questions/from_memory.question.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Can I apply pretrained pipelines on audio already loaded in memory?" 3 | alt_titles: 4 | - "Can I apply models on an audio array?" 5 | --- 6 | 7 | Yes: read [this tutorial](tutorials/applying_a_pipeline.ipynb) until the end. 
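For reference, here is a minimal sketch of the in-memory workflow that the tutorial walks through (the pipeline name and file name below are illustrative placeholders, and loading a pretrained pipeline may require a Hugging Face access token):

```python
import torchaudio
from pyannote.audio import Pipeline

# load a pretrained diarization pipeline
pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")

# read the audio into memory: waveform is a (channel, time) torch.Tensor
waveform, sample_rate = torchaudio.load("audio.wav")

# pass a {"waveform", "sample_rate"} mapping instead of a file path
diarization = pipeline({"waveform": waveform, "sample_rate": sample_rate})
```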
8 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/cli/train_config/task/MultiLabelSegmentation.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: pyannote.audio.tasks.MultiLabelSegmentation 3 | duration: 3.0 4 | warm_up: 0.0 5 | balance: null 6 | weight: null 7 | batch_size: 32 8 | num_workers: null 9 | pin_memory: False 10 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/cli/train_config/task/VoiceActivityDetection.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: pyannote.audio.tasks.VoiceActivityDetection 3 | duration: 3.0 4 | warm_up: 0.0 5 | balance: null 6 | weight: null 7 | batch_size: 32 8 | num_workers: null 9 | pin_memory: False 10 | -------------------------------------------------------------------------------- /diarizen/pipelines/utils.py: -------------------------------------------------------------------------------- 1 | # Licensed under the MIT license. 2 | # Copyright 2025 Brno University of Technology (author: Jiangyu Han, ihan@fit.vut.cz) 3 | 4 | def scp2path(scp_file): 5 | """Return the list of audio paths from a Kaldi-style scp file (one "<uri> <path>" pair per line).""" 6 | with open(scp_file) as f:  # context manager closes the file handle instead of leaking it 7 | return [line.strip().split()[1] for line in f] 8 | -------------------------------------------------------------------------------- /pyannote-audio/faq.yml: -------------------------------------------------------------------------------- 1 | # FAQtory settings 2 | 3 | faq_url: "https://github.com/pyannote/pyannote-audio/blob/develop/FAQ.md" # Replace this with the URL to your FAQ.md! 4 | 5 | questions_path: "./questions" # Where questions should be stored 6 | output_path: "./FAQ.md" # Where FAQ.md should be generated 7 | templates_path: ".faq" # Path to templates 8 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/cli/train_config/model/PyanNet.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: pyannote.audio.models.segmentation.PyanNet 3 | sincnet: 4 | stride: 10 5 | lstm: 6 | hidden_size: 128 7 | num_layers: 2 8 | bidirectional: true 9 | monolithic: true 10 | dropout: 0.5 11 | linear: 12 | hidden_size: 128 13 | num_layers: 2 -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/cli/train_config/task/SpeakerEmbedding.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: pyannote.audio.tasks.SupervisedRepresentationLearningWithArcFace 3 | min_duration: 2.0 4 | duration: 5.0 5 | num_classes_per_batch: 512 6 | num_chunks_per_class: 1 7 | margin: 2.0 8 | scale: 12.0 9 | num_workers: null 10 | pin_memory: False 11 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/utils/params.py: -------------------------------------------------------------------------------- 1 | # TODO - make it depth-recursive 2 | # TODO - switch to Omegaconf maybe?
3 | 4 | from typing import Optional 5 | 6 | 7 | def merge_dict(defaults: dict, custom: Optional[dict] = None): 8 | params = dict(defaults) 9 | if custom is not None: 10 | params.update(custom) 11 | return params 12 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/cli/train_config/model/SSeRiouSS.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: pyannote.audio.models.segmentation.SSeRiouSS 3 | wav2vec: WAVLM_BASE 4 | wav2vec_layer: -1 5 | lstm: 6 | hidden_size: 128 7 | num_layers: 4 8 | bidirectional: true 9 | monolithic: true 10 | dropout: 0.5 11 | linear: 12 | hidden_size: 128 13 | num_layers: 2 14 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/cli/train_config/task/OverlappedSpeechDetection.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: pyannote.audio.tasks.OverlappedSpeechDetection 3 | duration: 3.0 4 | warm_up: 0.0 5 | balance: null 6 | overlap: 7 | probability: 0.5 8 | snr_min: 0.0 9 | snr_max: 10.0 10 | weight: null 11 | batch_size: 32 12 | num_workers: null 13 | pin_memory: False 14 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops 2 | flit 3 | h5py 4 | joblib 5 | jupyterlab 6 | tensorboard 7 | librosa 8 | matplotlib 9 | numpy==1.26.4 10 | onnxruntime-gpu 11 | openpyxl # for saving results to excel 12 | pandas 13 | pesq 14 | pre-commit 15 | pystoi 16 | pyyaml 17 | scipy 18 | soundfile 19 | tabulate 20 | toml 21 | torchinfo 22 | tqdm 23 | accelerate==1.6.0 24 | thop 25 | -------------------------------------------------------------------------------- /pyannote-audio/.faq/FAQ.md: -------------------------------------------------------------------------------- 1 | 2 | # Frequently Asked Questions 3 | 4 | {%- for question in questions %} 5 | - [{{ question.title }}](#{{ question.slug }}) 6 | {%- endfor %} 7 | 8 | 9 | {%- for question in questions %} 10 | 11 | 12 | ## {{ question.title }} 13 | 14 | {{ question.body }} 15 | 16 | {%- endfor %} 17 |
19 | 20 | Generated by [FAQtory](https://github.com/willmcgugan/faqtory) 21 | -------------------------------------------------------------------------------- /pyannote-audio/doc/source/index.rst: -------------------------------------------------------------------------------- 1 | ############## 2 | pyannote.audio 3 | ############## 4 | 5 | `pyannote.audio` is an open-source Python library that provides neural building blocks for speaker diarization. 6 | 7 | Installation 8 | ============ 9 | 10 | :: 11 | 12 | $ conda create -n pyannote python=3.10 13 | $ conda activate pyannote 14 | $ pip install pyannote.audio 15 | 16 | 17 | API documentation 18 | ================= 19 | 20 | .. toctree:: 21 | :maxdepth: 2 22 | -------------------------------------------------------------------------------- /pyannote-audio/questions/pyannote.question.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: "How does one spell and pronounce pyannote.audio?" 3 | alt_titles: 4 | - "Why the name of the library?" 5 | - "Why the logo of the library?" 6 | --- 7 | 8 | 📝 Written in lower case: `pyannote.audio` (or `pyannote` if you are lazy). Not `PyAnnote` nor `PyAnnotate` (sic). 9 | 📢 Pronounced like the french verb `pianoter`. `pi` like in `pi`ano, not `py` like in `py`thon. 10 | 🎹 `pianoter` means to play the piano (hence the logo 🤯). 11 | -------------------------------------------------------------------------------- /pyannote-audio/requirements.txt: -------------------------------------------------------------------------------- 1 | asteroid-filterbanks >=0.4 2 | einops >=0.6.0 3 | huggingface_hub >= 0.13.0 4 | lightning >= 2.0.1 5 | omegaconf >=2.1,<3.0 6 | pyannote.core >= 5.0.0 7 | pyannote.database >= 5.0.1 8 | pyannote.metrics >= 3.2 9 | pyannote.pipeline >= 3.0.1 10 | pytorch_metric_learning >= 2.1.0 11 | rich >= 12.0.0 12 | semver >= 3.0.0 13 | soundfile >= 0.12.1 14 | speechbrain >= 0.5.14 15 | tensorboardX >= 2.6 16 | torch >= 2.0.0 17 | torch_audiomentations >= 0.11.0 18 | torchaudio >= 2.1.1 19 | torchmetrics >= 0.11.0 20 | -------------------------------------------------------------------------------- /pyannote-audio/tests/test_run_notebooks.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | 3 | import papermill as pm 4 | 5 | 6 | def test_can_run_notebooks(): 7 | # Search for all notebooks in directory 8 | notebooks = glob("**/notebook/**/*.ipynb") 9 | for nb in notebooks: 10 | try: 11 | pm.execute_notebook( 12 | nb, "/dev/null", progress_bar=False, kernel_name="python" 13 | ) 14 | except Exception as e: 15 | # Which notebook caused the error 16 | raise Exception(nb, e) 17 | -------------------------------------------------------------------------------- /recipes/diar_ssl/data/AMI_AliMeeting_AISHELL4/test/AMI/all.uem: -------------------------------------------------------------------------------- 1 | EN2002a 1 0.000 2142.709375 2 | EN2002b 1 0.000 1786.848000 3 | EN2002c 1 0.000 2972.256000 4 | EN2002d 1 0.000 2209.898688 5 | ES2004a 1 0.000 1049.354687 6 | ES2004b 1 0.000 2345.493375 7 | ES2004c 1 0.000 2334.368000 8 | ES2004d 1 0.000 2222.290687 9 | IS1009a 1 0.000 838.833313 10 | IS1009b 1 0.000 2052.333312 11 | IS1009c 1 0.000 1820.833312 12 | IS1009d 1 0.000 1944.500000 13 | TS3003a 1 0.000 1505.642625 14 | TS3003b 1 0.000 2210.304000 15 | TS3003c 1 0.000 2570.000000 16 | TS3003d 1 0.000 2618.200000 17 | 
-------------------------------------------------------------------------------- /pyannote-audio/.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | 3 | contact_links: 4 | 5 | - name: Feature request 6 | url: https://github.com/pyannote/pyannote-audio/discussions 7 | about: Suggest an idea for this project. 8 | 9 | - name: Consulting 10 | url: https://herve.niderb.fr/consulting 11 | about: Using pyannote.audio in production? Make the most of it thanks to our consulting services. 12 | 13 | - name: Premium models 14 | url: https://forms.office.com/e/GdqwVgkZ5C 15 | about: We are considering selling premium models, extensions, or services around pyannote.audio. 16 | -------------------------------------------------------------------------------- /pyannote-audio/tutorials/assets/sample.rttm: -------------------------------------------------------------------------------- 1 | SPEAKER sample 1 6.690 0.430 speaker90 2 | SPEAKER sample 1 7.550 0.800 speaker91 3 | SPEAKER sample 1 8.320 1.700 speaker90 4 | SPEAKER sample 1 9.920 1.110 speaker91 5 | SPEAKER sample 1 10.570 4.130 speaker90 6 | SPEAKER sample 1 14.490 3.430 speaker91 7 | SPEAKER sample 1 18.050 3.440 speaker90 8 | SPEAKER sample 1 18.150 0.440 speaker91 9 | SPEAKER sample 1 21.780 6.720 speaker91 10 | SPEAKER sample 1 27.850 2.150 speaker90 11 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/sample/sample.rttm: -------------------------------------------------------------------------------- 1 | SPEAKER sample 1 6.690 0.430 speaker90 2 | SPEAKER sample 1 7.550 0.800 speaker91 3 | SPEAKER sample 1 8.320 1.700 speaker90 4 | SPEAKER sample 1 9.920 1.110 speaker91 5 | SPEAKER sample 1 10.570 4.130 speaker90 6 | SPEAKER sample 1 14.490 3.430 speaker91 7 | SPEAKER sample 1 18.050 3.440 speaker90 8 | SPEAKER sample 1 18.150 0.440 speaker91 9 | SPEAKER sample 1 21.780 6.720 speaker91 10 | SPEAKER sample 1 27.850 2.150 speaker90 11 | -------------------------------------------------------------------------------- /recipes/diar_ssl/data/AMI_AliMeeting_AISHELL4/test/AISHELL4/all.uem: -------------------------------------------------------------------------------- 1 | L_R003S01C02 1 0.000 2362.945 2 | L_R003S02C02 1 0.000 2246.708 3 | L_R003S03C02 1 0.000 2268.015 4 | L_R003S04C02 1 0.000 2324.984 5 | L_R004S01C01 1 0.000 2195.433 6 | L_R004S02C01 1 0.000 2226.67 7 | L_R004S03C01 1 0.000 2227.768 8 | L_R004S06C01 1 0.000 2210.35 9 | M_R003S01C01 1 0.000 2282.21 10 | M_R003S02C01 1 0.000 2321.259 11 | M_R003S04C01 1 0.000 2256.727 12 | M_R003S05C01 1 0.000 2341.215 13 | S_R003S01C01 1 0.000 2220.529 14 | S_R003S02C01 1 0.000 2361.298 15 | S_R003S03C01 1 0.000 2279.261 16 | S_R003S04C01 1 0.000 2393.927 17 | S_R004S01C01 1 0.000 2345.908 18 | S_R004S02C01 1 0.000 2319.371 19 | S_R004S03C01 1 0.000 2327.119 20 | S_R004S04C01 1 0.000 2299.396 21 | -------------------------------------------------------------------------------- /pyannote-audio/.github/workflows/pypi.yml: -------------------------------------------------------------------------------- 1 | name: PyPI 2 | 3 | on: 4 | push: 5 | tags: 6 | - '*' 7 | 8 | jobs: 9 | deploy: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v1 13 | - name: Set up Python 14 | uses: actions/setup-python@v1 15 | with: 16 | python-version: '3.x' 17 | - name: Install dependencies 18 | run: | 19 | python -m pip install --upgrade pip 20 
| pip install setuptools wheel twine 21 | - name: Build and publish 22 | env: 23 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 24 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 25 | run: | 26 | python setup.py sdist bdist_wheel 27 | twine upload dist/* 28 | -------------------------------------------------------------------------------- /pyannote-audio/questions/streaming.question.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Does pyannote support streaming speaker diarization?" 3 | alt_titles: 4 | - "Is it possible to do realtime speaker diarization?" 5 | - "Can it process online audio buffers?" 6 | --- 7 | 8 | **Short answer:** not out of the box, no. 9 | 10 | **Long answer:** [I](https://herve.niderb.fr) am looking for sponsors to add this feature. In the meantime, [`diart`](https://github.com/juanmc2005/StreamingSpeakerDiarization) is the closest you can get from a streaming `pyannote.audio`. You might also be interested in [this blog post](https://herve.niderb.fr/fastpages/2021/08/05/Streaming-voice-activity-detection-with-pyannote.html) about streaming voice activity detection based on `pyannote.audio`. 11 | -------------------------------------------------------------------------------- /pyannote-audio/.github/stale.yml: -------------------------------------------------------------------------------- 1 | # Number of days of inactivity before an issue becomes stale 2 | daysUntilStale: 180 3 | # Number of days of inactivity before a stale issue is closed 4 | daysUntilClose: 30 5 | # Issues with these labels will never be considered stale 6 | exemptLabels: 7 | - pinned 8 | - security 9 | # Label to use when marking an issue as stale 10 | staleLabel: wontfix 11 | # Comment to post when marking an issue as stale. Set to `false` to disable 12 | markComment: > 13 | This issue has been automatically marked as stale because it has not had 14 | recent activity. It will be closed if no further activity occurs. Thank you 15 | for your contributions. 16 | # Comment to post when closing a stale issue. Set to `false` to disable 17 | closeComment: false 18 | -------------------------------------------------------------------------------- /pyannote-audio/questions/bad_performance.question.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: "How can I improve performance?" 3 | alt_titles: 4 | - "Pretrained pipelines do not produce good results on my data. What can I do?" 5 | - "It does not work! Help me!" 6 | --- 7 | 8 | **Long answer:** 9 | 10 | 1. Manually annotate dozens of conversations as precisely as possible. 11 | 2. Separate them into train (80%), development (10%) and test (10%) subsets. 12 | 3. Setup the data for use with [`pyannote.database`](https://github.com/pyannote/pyannote-database#speaker-diarization). 13 | 4. Follow [this recipe](https://github.com/pyannote/pyannote-audio/blob/develop/tutorials/adapting_pretrained_pipeline.ipynb). 14 | 5. Enjoy. 15 | 16 | **Also:** [I am available](https://herve.niderb.fr) for contracting to help you with that. 
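For step 3 above, a minimal sketch of a `pyannote.database` configuration (`database.yml`) — the database name, protocol name, and paths are placeholders, and `{uri}` is expanded by `pyannote.database` to each recording identifier:

```yaml
Databases:
  MyDatabase: /path/to/audio/{uri}.wav

Protocols:
  MyDatabase:
    SpeakerDiarization:
      MyProtocol:
        train:
          uri: lists/train.lst          # one recording id per line
          annotation: rttm/train.rttm   # reference speaker segments
          annotated: uem/train.uem      # regions that were actually annotated
        development:
          uri: lists/dev.lst
          annotation: rttm/dev.rttm
          annotated: uem/dev.uem
```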
17 | -------------------------------------------------------------------------------- /recipes/diar_ssl/data/AMI_AliMeeting_AISHELL4/test/AliMeeting/all.uem: -------------------------------------------------------------------------------- 1 | R8002_M8002_MS802 1 0.000 2066.52 2 | R8002_M8003_MS803 1 0.000 2042.126 3 | R8004_M8005_MS803 1 0.000 2093.014 4 | R8004_M8006_MS805 1 0.000 1965.215 5 | R8005_M8007_MS806 1 0.000 1869.669 6 | R8005_M8008_MS806 1 0.000 1960.822 7 | R8005_M8009_MS802 1 0.000 1776.99 8 | R8006_M8012_MS803 1 0.000 1832.183 9 | R8008_M8014_MS807 1 0.000 1937.092 10 | R8008_M8015_MS808 1 0.000 1896.185 11 | R8008_M8016_MS808 1 0.000 1838.98 12 | R8008_M8017_MS808 1 0.000 1820.052 13 | R8009_M8021_MS810 1 0.000 2008.399 14 | R8009_M8022_MS810 1 0.000 2018.56 15 | R8009_M8023_MS811 1 0.000 1962.73 16 | R8009_M8024_MS811 1 0.000 1854.474 17 | R8009_M8025_MS811 1 0.000 2065.057 18 | R8009_M8026_MS812 1 0.000 1983.113 19 | R8009_M8027_MS812 1 0.000 1881.248 20 | R8009_M8028_MS812 1 0.000 1923.043 21 | -------------------------------------------------------------------------------- /pyannote-audio/tests/utils/preview.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from IPython.display import Audio 3 | 4 | from pyannote.audio.utils.preview import listen 5 | from pyannote.core import Segment 6 | from pyannote.database import FileFinder, get_protocol 7 | 8 | 9 | def test_file(): 10 | protocol = get_protocol( 11 | "Debug.SpeakerDiarization.Debug", preprocessors={"audio": FileFinder()} 12 | ) 13 | return next(protocol.train()) 14 | 15 | 16 | def test_returns_audio_object(): 17 | audio_file = test_file() 18 | ipython_audio = listen(audio_file) 19 | assert isinstance(ipython_audio, Audio) 20 | 21 | 22 | def test_can_crop(): 23 | audio_file = test_file() 24 | listen(audio_file, Segment(0, 1)) 25 | 26 | 27 | def test_fail_crop_too_large(): 28 | with pytest.raises(ValueError): 29 | audio_file = test_file() 30 | duration = audio_file.duration 31 | listen(audio_file, Segment(0, duration * 2)) 32 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/cli/config/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | run: 4 | dir: ${protocol}/${now:%Y-%m-%dT%H:%M:%S.%fZ} 5 | 6 | sweep: 7 | dir: ${protocol}/${now:%Y-%m-%dT%H:%M:%S.%fZ} 8 | subdir: ${hydra.job.num} 9 | 10 | output_subdir: "" 11 | 12 | help: 13 | app_name: pyannote-audio-train 14 | 15 | # Help header, customize to describe your app to your users 16 | header: == ${hydra.help.app_name} == 17 | 18 | footer: |- 19 | Powered by Hydra (https://hydra.cc) 20 | Use --hydra-help to view Hydra specific help 21 | 22 | template: |- 23 | ${hydra.help.header} 24 | 25 | pyannote-audio-train protocol={protocol_name} 26 | task={task} task.param=... 27 | model={model} model.param=... 28 | optimizer={optimizer} optimizer.param=... 29 | scheduler={scheduler} scheduler.param=... 
30 | 31 | ${hydra.help.footer} 32 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/cli/evaluate_config/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | run: 4 | dir: ${protocol}/${now:%Y-%m-%dT%H:%M:%S.%fZ} 5 | 6 | sweep: 7 | dir: ${protocol}/${now:%Y-%m-%dT%H:%M:%S.%fZ} 8 | subdir: ${hydra.job.num} 9 | 10 | output_subdir: "" 11 | 12 | help: 13 | app_name: pyannote-audio-eval 14 | 15 | # Help header, customize to describe your app to your users 16 | header: == ${hydra.help.app_name} == 17 | 18 | footer: |- 19 | Powered by Hydra (https://hydra.cc) 20 | Use --hydra-help to view Hydra specific help 21 | 22 | template: |- 23 | ${hydra.help.header} 24 | 25 | pyannote-audio-eval registry={path_to_database.yml} 26 | protocol={protocol_name} 27 | subset={test | development | train} 28 | model={path_to_pretrained_model} 29 | warm_up={warm_up_duration_in_seconds} 30 | 31 | ${hydra.help.footer} 32 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/cli/train_config/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | run: 4 | dir: ${protocol}/${now:%Y-%m-%dT%H:%M:%S.%fZ} 5 | 6 | sweep: 7 | dir: ${protocol}/${now:%Y-%m-%dT%H:%M:%S.%fZ} 8 | subdir: ${hydra.job.num} 9 | 10 | output_subdir: "" 11 | 12 | help: 13 | app_name: pyannote-audio-train 14 | 15 | # Help header, customize to describe your app to your users 16 | header: == ${hydra.help.app_name} == 17 | 18 | footer: |- 19 | Powered by Hydra (https://hydra.cc) 20 | Use --hydra-help to view Hydra specific help 21 | 22 | template: |- 23 | ${hydra.help.header} 24 | 25 | pyannote-audio-train protocol={protocol_name} 26 | +task={task} task.param=... 27 | +model={model} model.param=... 28 | optimizer={optimizer} optimizer.param=... 29 | scheduler={scheduler} scheduler.param=... 
30 | 31 | ${hydra.help.footer} 32 | -------------------------------------------------------------------------------- /pyannote-audio/.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: [develop] 6 | pull_request: 7 | branches: [develop] 8 | 9 | jobs: 10 | build: 11 | timeout-minutes: 20 12 | runs-on: ${{ matrix.os }} 13 | strategy: 14 | matrix: 15 | os: [ubuntu-latest] 16 | python-version: [3.8, 3.9, "3.10"] 17 | steps: 18 | - uses: actions/checkout@v2 19 | - name: Set up Python ${{ matrix.python-version }} 20 | uses: actions/setup-python@v2 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | - name: Install libsndfile 24 | if: matrix.os == 'ubuntu-latest' 25 | run: | 26 | sudo apt-get update 27 | sudo apt-get install libsndfile1 28 | - name: Install pyannote.audio 29 | run: | 30 | pip install -e .[dev,testing] 31 | - name: Test with pytest 32 | run: | 33 | pytest -k "not test_cli.py" 34 | -------------------------------------------------------------------------------- /pyannote-audio/.github/workflows/test_cli.yml: -------------------------------------------------------------------------------- 1 | name: CLI tests 2 | 3 | on: 4 | push: 5 | branches: [develop] 6 | pull_request: 7 | branches: [develop] 8 | 9 | jobs: 10 | build: 11 | timeout-minutes: 20 12 | runs-on: ${{ matrix.os }} 13 | strategy: 14 | matrix: 15 | os: [ubuntu-latest] 16 | python-version: ["3.10"] 17 | steps: 18 | - uses: actions/checkout@v2 19 | - name: Set up Python ${{ matrix.python-version }} 20 | uses: actions/setup-python@v2 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | - name: Install libsndfile 24 | if: matrix.os == 'ubuntu-latest' 25 | run: | 26 | sudo apt-get update 27 | sudo apt-get install libsndfile1 28 | - name: Install pyannote.audio 29 | run: | 30 | pip install -e .[dev,testing,cli] 31 | - name: Test with pytest 32 | run: | 33 | pytest tests/test_cli.py 34 | -------------------------------------------------------------------------------- /pyannote-audio/tests/tasks/test_specifications.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pyannote.database import FileFinder, get_protocol 3 | 4 | from pyannote.audio.core.model import Model 5 | from pyannote.audio.core.task import UnknownSpecificationsError 6 | from pyannote.audio.tasks import SpeakerDiarization 7 | 8 | 9 | @pytest.fixture() 10 | def protocol(): 11 | return get_protocol( 12 | "Debug.SpeakerDiarization.Debug", preprocessors={"audio": FileFinder()} 13 | ) 14 | 15 | 16 | def test_unknown_specifications_error_raised_on_non_setup_task(protocol): 17 | task = SpeakerDiarization(protocol=protocol) 18 | with pytest.raises(UnknownSpecificationsError): 19 | _ = task.specifications 20 | 21 | 22 | def test_unknown_specifications_error_raised_on_non_setup_model_task(protocol): 23 | task = SpeakerDiarization(protocol=protocol) 24 | model = Model.from_pretrained("pyannote/ci-segmentation") 25 | model.task = task 26 | with pytest.raises(UnknownSpecificationsError): 27 | _ = model.specifications 28 | -------------------------------------------------------------------------------- /pyannote-audio/tests/test_clustering.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pyannote.audio.pipelines.clustering import AgglomerativeClustering 4 | 5 | 6 | def test_agglomerative_clustering_num_cluster(): 7 | 
""" 8 | Make sure AgglomerativeClustering doesn't "over-merge" clusters when initial 9 | clustering already matches target num_clusters, cf 10 | https://github.com/pyannote/pyannote-audio/issues/1525 11 | """ 12 | 13 | # 2 embeddings different enough 14 | embeddings = np.array([[1.0, 1.0, 1.0, 1.0], [1.0, 2.0, 1.0, 2.0]]) 15 | 16 | # clustering with params that should yield 1 cluster per embedding 17 | clustering = AgglomerativeClustering().instantiate( 18 | { 19 | "method": "centroid", 20 | "min_cluster_size": 0, 21 | "threshold": 0.0, 22 | } 23 | ) 24 | 25 | # request 2 clusters 26 | clusters = clustering.cluster( 27 | embeddings=embeddings, min_clusters=2, max_clusters=2, num_clusters=2 28 | ) 29 | assert np.array_equal(clusters, np.array([0, 1])) 30 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/cli/train_config/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: pytorch_lightning.Trainer 3 | accelerator: auto 4 | accumulate_grad_batches: 1 5 | benchmark: null # TODO: automatically set to True when using fixed duration chunks 6 | deterministic: False 7 | check_val_every_n_epoch: 1 8 | devices: auto 9 | detect_anomaly: False 10 | enable_checkpointing: True 11 | enable_model_summary: True 12 | enable_progress_bar: True 13 | fast_dev_run: False 14 | gradient_clip_val: null 15 | gradient_clip_algorithm: norm 16 | limit_predict_batches: 1.0 17 | limit_test_batches: 1.0 18 | limit_train_batches: 1.0 19 | limit_val_batches: 1.0 20 | log_every_n_steps: 50 21 | max_epochs: 1000 22 | max_steps: -1 23 | max_time: null 24 | min_epochs: 1 25 | min_steps: null 26 | num_nodes: 1 27 | num_sanity_val_steps: 2 28 | overfit_batches: 0.0 29 | precision: 32 30 | profiler: null 31 | reload_dataloaders_every_n_epochs: 0 32 | use_distributed_sampler: True # TODO: check what this does exactly 33 | strategy: auto 34 | sync_batchnorm: False 35 | val_check_interval: 1.0 36 | -------------------------------------------------------------------------------- /recipes/diar_ssl/data/AMI_AliMeeting_AISHELL4/test/AMI/wav.scp: -------------------------------------------------------------------------------- 1 | IS1009a /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/IS1009a.wav 2 | IS1009b /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/IS1009b.wav 3 | IS1009c /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/IS1009c.wav 4 | IS1009d /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/IS1009d.wav 5 | ES2004a /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/ES2004a.wav 6 | ES2004b /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/ES2004b.wav 7 | ES2004c /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/ES2004c.wav 8 | ES2004d /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/ES2004d.wav 9 | TS3003a /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/TS3003a.wav 10 | TS3003b /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/TS3003b.wav 11 | TS3003c /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/TS3003c.wav 12 | TS3003d /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/TS3003d.wav 13 | EN2002a /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/EN2002a.wav 14 | EN2002b /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/EN2002b.wav 15 | EN2002c /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/EN2002c.wav 16 | EN2002d /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/EN2002d.wav 17 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 
1 | MIT License 2 | 3 | Copyright (c) 2024 BUT Speech@FIT 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /pyannote-audio/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 CNRS 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /pyannote-audio/.github/workflows/doc.yml: -------------------------------------------------------------------------------- 1 | name: Documentation 2 | on: 3 | push: 4 | branches: 5 | - master 6 | 7 | jobs: 8 | build-and-deploy: 9 | runs-on: ubuntu-latest 10 | strategy: 11 | max-parallel: 4 12 | matrix: 13 | python-version: ["3.9"] 14 | 15 | steps: 16 | - uses: actions/checkout@v1 17 | - name: Set up Python ${{ matrix.python-version }} 18 | uses: actions/setup-python@v1 19 | with: 20 | python-version: ${{ matrix.python-version }} 21 | - name: Install 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install . 
25 | pip install -r doc/requirements.txt 26 | - name: Build documentation 27 | run: | 28 | make --directory=doc html 29 | touch ./doc/build/html/.nojekyll 30 | - name: Deploy 31 | env: 32 | ACTIONS_DEPLOY_KEY: ${{ secrets.ACTIONS_DEPLOY_KEY }} 33 | PUBLISH_BRANCH: gh-pages 34 | PUBLISH_DIR: ./doc/build/html 35 | SCRIPT_MODE: true 36 | run: | 37 | wget https://raw.githubusercontent.com/peaceiris/actions-gh-pages/v2/entrypoint.sh 38 | bash ./entrypoint.sh 39 | -------------------------------------------------------------------------------- /pyannote-audio/.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | exclude: '^docs/conf.py' 2 | 3 | repos: 4 | # # Clean Notebooks 5 | # - repo: https://github.com/kynan/nbstripout 6 | # rev: master 7 | # hooks: 8 | # - id: nbstripout 9 | # Format Code 10 | - repo: https://github.com/ambv/black 11 | rev: 22.3.0 12 | hooks: 13 | - id: black 14 | 15 | # Sort imports 16 | - repo: https://github.com/PyCQA/isort 17 | rev: 5.12.0 18 | hooks: 19 | - id: isort 20 | args: ["--profile", "black"] 21 | 22 | # Formatting, Whitespace, etc 23 | - repo: https://github.com/pre-commit/pre-commit-hooks 24 | rev: v2.2.3 25 | hooks: 26 | - id: trailing-whitespace 27 | - id: check-added-large-files 28 | args: ['--maxkb=1000'] 29 | - id: check-ast 30 | - id: check-json 31 | - id: check-merge-conflict 32 | - id: check-xml 33 | - id: check-yaml 34 | - id: debug-statements 35 | - id: end-of-file-fixer 36 | - id: requirements-txt-fixer 37 | - id: mixed-line-ending 38 | args: ['--fix=no'] 39 | - id: flake8 40 | args: ['--ignore=E203,E501,F811,E712,W503'] 41 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/core/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020 CNRS 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
22 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/models/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020 CNRS 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020 CNRS 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
22 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/tasks/embedding/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020 CNRS 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/cli/train_config/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020-2021 CNRS 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
22 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/tasks/segmentation/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020 CNRS 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/cli/evaluate_config/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020-2021 CNRS 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
22 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/torchmetrics/functional/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2022- CNRS 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/torchmetrics/functional/audio/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2022- CNRS 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | -------------------------------------------------------------------------------- /recipes/diar_ssl_pruning/convert_wavlm_from_hf.py: -------------------------------------------------------------------------------- 1 | # Licensed under the MIT license. 
2 | # Copyright 2025 Brno University of Technology (author: Jiangyu Han, ihan@fit.vut.cz) 3 | 4 | import os 5 | import argparse 6 | 7 | from diarizen.models.pruning.utils import convert_wavlm 8 | 9 | def run(args): 10 | hf_dir = os.path.abspath(args.hf_dir) 11 | out_dir = os.path.abspath(args.out_dir) 12 | 13 | if not os.path.isdir(hf_dir): 14 | raise FileNotFoundError(f"HuggingFace directory does not exist: {hf_dir}") 15 | 16 | os.makedirs(out_dir, exist_ok=True) 17 | convert_wavlm(hf_dir, out_dir) 18 | 19 | if __name__ == "__main__": 20 | parser = argparse.ArgumentParser( 21 | description="Convert HuggingFace WavLM model to custom format " 22 | "(e.g. from /pre-trained/HF/wavlm-base-plus)" 23 | ) 24 | parser.add_argument( 25 | "hf_dir", type=str, 26 | help="Path to the HuggingFace WavLM directory containing config.json and pytorch_model.bin." 27 | ) 28 | parser.add_argument( 29 | "out_dir", type=str, 30 | help="Path to the output directory where the converted model will be saved." 31 | ) 32 | 33 | args = parser.parse_args() 34 | run(args) -------------------------------------------------------------------------------- /pyannote-audio/pyannote/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020 CNRS 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | __import__("pkg_resources").declare_namespace(__name__) 24 | -------------------------------------------------------------------------------- /recipes/diar_ssl/README.md: -------------------------------------------------------------------------------- 1 | # DiariZen EEND Module 2 | This directory contains scripts for training the DiariZen EEND module and for running global inference for speaker diarization. 3 | 4 | 5 | ## Results (collar=0s) 6 | | System | Features | AMI | AISHELL-4 | AliMeeting | 7 | |:------------|:----------------:|:------:|:------------:|:------------:| 8 | | [Pyannote v3.1](https://github.com/pyannote/pyannote-audio) | SincNet | 22.4 | 12.2 | 24.4 | 9 | | DiariZen | Fbank | 19.7 | 12.5 | 21.0 | 10 | | | WavLM-frozen | 17.0 | 11.7 | 19.9 | 11 | | | WavLM-updated | **15.4** | **11.7** | **17.6** | 12 | 13 | 14 | ## Citation 15 | If you find this work helpful, please consider citing: 16 | J. Han, F. Landini, J. Rohdin, A. Silnova, M. Diez, and L. Burget, [Leveraging Self-Supervised Learning for Speaker Diarization](https://arxiv.org/pdf/2409.09408), in Proc. ICASSP, 2025.
17 | ``` 18 | @inproceedings{han2025leveraging, 19 | title={Leveraging self-supervised learning for speaker diarization}, 20 | author={Han, Jiangyu and Landini, Federico and Rohdin, Johan and Silnova, Anna and Diez, Mireia and Burget, Luk{\'a}{\v{s}}}, 21 | booktitle={Proc. ICASSP}, 22 | year={2025} 23 | } 24 | 25 | ``` 26 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/cli/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020-2021 CNRS 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | from .pretrained import pretrained 24 | 25 | __all__ = [ 26 | "pretrained", 27 | ] 28 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/augmentation/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020 CNRS 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
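# NOTE: the import below re-exports MixSpeakerDiarization so the augmentation can be imported directly from pyannote.audio.augmentation.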
22 | 23 | 24 | from .mix import MixSpeakerDiarization 25 | 26 | __all__ = ["MixSpeakerDiarization"] 27 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/models/segmentation/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020- CNRS 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | from .PyanNet import PyanNet 24 | from .SSeRiouSS import SSeRiouSS 25 | 26 | __all__ = ["PyanNet", "SSeRiouSS"] 27 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/torchmetrics/classification/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2023- CNRS 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | 24 | from .equal_error_rate import EqualErrorRate 25 | 26 | __all__ = [ 27 | "EqualErrorRate", 28 | ] 29 | -------------------------------------------------------------------------------- /pyannote-audio/questions/offline.question.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Can I use gated models (and pipelines) offline?" 
3 | alt_titles: 4 | - "Why does one need to authenticate to access the pretrained models?" 5 | - "Can I use pyannote.audio pretrained pipelines without the Hugging Face token?" 6 | - "How can I solve the permission issue?" 7 | --- 8 | 9 | **Short answer**: yes, see [this tutorial](tutorials/applying_a_model.ipynb) for models and [that one](tutorials/applying_a_pipeline.ipynb) for pipelines. 10 | 11 | **Long answer**: gating models and pipelines allows [me](https://herve.niderb.fr) to know a bit more about the `pyannote.audio` user base, and eventually helps me write grant proposals to make `pyannote.audio` even better. So, please fill in the gating forms as precisely as possible. 12 | 13 | For instance, before gating `pyannote/speaker-diarization`, I had no idea that so many people were relying on it in production. Hint: sponsors are more than welcome! Maintaining open source libraries is time-consuming. 14 | 15 | That being said, this whole authentication process does not prevent you from using official `pyannote.audio` models offline (i.e. without going through the authentication process in every `docker run ...` or whatever you are using in production): see [this tutorial](tutorials/applying_a_model.ipynb) for models and [that one](tutorials/applying_a_pipeline.ipynb) for pipelines. 16 | -------------------------------------------------------------------------------- /pyannote-audio/tests/test_sample.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # 3 | # Copyright (c) 2024- CNRS 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in 13 | # all copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE.
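# NOTE: a minimal smoke test for the bundled sample file; it only asserts that SAMPLE_FILE exposes both its reference "annotation" and its "annotated" regions.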
22 | 23 | 24 | def test_sample(): 25 | from pyannote.audio.sample import SAMPLE_FILE 26 | 27 | assert "annotation" in SAMPLE_FILE 28 | assert "annotated" in SAMPLE_FILE 29 | -------------------------------------------------------------------------------- /pyannote-audio/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | .env/ 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *,cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | 56 | # Sphinx documentation 57 | docs/_build/ 58 | 59 | # PyBuilder 60 | target/ 61 | 62 | #Ipython Notebook 63 | .ipynb_checkpoints 64 | 65 | notebooks 66 | 67 | experiments 68 | *~ 69 | 70 | *.npy 71 | *.pt 72 | *events.out.tfevents* 73 | *.csv 74 | 75 | # PyCharm 76 | .idea/ 77 | 78 | gh-pages 79 | gh-pages.pub 80 | 81 | *.zip 82 | .mypy_cache/ 83 | .vscode/ 84 | 85 | **/lightning_logs/** 86 | 87 | # Version Output 88 | pyannote/audio/version.py 89 | 90 | # vim 91 | .vim 92 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/cli/pretrained.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020-2021 CNRS 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
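# NOTE: the map_location lambda below returns each storage unchanged, which keeps all checkpoint tensors on CPU at load time, so GPU-trained checkpoints can also be loaded on CPU-only machines.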
22 | 23 | 24 | from typing import Text 25 | from pyannote.audio import Model 26 | 27 | 28 | def pretrained(checkpoint: Text): 29 | return Model.from_pretrained(checkpoint, map_location=lambda storage, loc: storage) 30 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/models/embedding/wespeaker/LICENSE.WeSpeaker: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021 Shuai Wang (wsstriving@gmail.com) 2 | 2022 Zhengyang Chen (chenzhengyang117@gmail.com) 3 | 2023 Bing Han (hanbing97@sjtu.edu.cn) 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 16 | 17 | File `resnet.py` has been borrowed from WeSpeaker, which is available under the Apache License, Version 2.0. 18 | 19 | The original file is available at https://github.com/wenet-e2e/wespeaker/blob/c20d765295359e681321625fbefc1a02e8794163/wespeaker/models/resnet.py 20 | 21 | Neither Shuai Wang (@wsstriving on GitHub) nor I (Hervé Bredin, @hbredin on GitHub) am a lawyer, but we both agreed that putting this license file in this directory is enough to comply with the license. See https://github.com/pyannote/pyannote-audio/issues/1537#issuecomment-1808029836. If you know more about this potential MIT/Apache 2.0 compatibility issue, please let us know.
22 | 23 | 24 | from .CosineAnnealingWarmRestarts import CosineAnnealingWarmRestarts 25 | from .CyclicLR import CyclicLR 26 | from .ReduceLROnPlateau import ReduceLROnPlateau 27 | 28 | __all__ = ["ReduceLROnPlateau", "CyclicLR", "CosineAnnealingWarmRestarts"] 29 | -------------------------------------------------------------------------------- /pyannote-audio/tests/conftest.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020- CNRS 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | 24 | def pytest_sessionstart(session): 25 | """ 26 | Called after the Session object has been created and 27 | before performing collection and entering the run test loop. 28 | """ 29 | 30 | from pyannote.database import registry 31 | 32 | registry.load_database("tests/data/database.yml") 33 | -------------------------------------------------------------------------------- /pyannote-audio/doc/gen_docs.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script generates the rst docs for the API 3 | """ 4 | 5 | import os 6 | from os import path 7 | 8 | bp = breakpoint 9 | 10 | 11 | def capitalise(s): 12 | news = "" 13 | for word in s.split("_"): 14 | news += word.capitalize() 15 | return news 16 | 17 | 18 | def process_dir(level, p): 19 | md = "" 20 | basename = path.basename(p) 21 | 22 | title = capitalise(basename) 23 | md += f"{'#'*level} {title}\n\n" 24 | subdirs = os.listdir(p) 25 | 26 | for f in subdirs: 27 | m = path.join(p, f)  # join against the directory being processed (fixes a bug: this previously joined against the top-level module dir) 28 | if path.isdir(m): 29 | md += process_dir(level + 1, m) 30 | else: 31 | if "__" in f: 32 | continue 33 | module = m[3:].replace("/", ".")[:-3] 34 | md += f""" 35 | ```eval_rst 36 | .. 
automodule:: {module} 37 | :members: 38 | 39 | ``` 40 | 41 | """ 42 | return md 43 | 44 | 45 | DIR = "../pyannote/audio" 46 | 47 | for module in os.listdir(DIR): 48 | # Each folder will become an rst file 49 | # Each file/folder will have a # prepended to it 50 | # Recursively we will add another # at each level 51 | 52 | # Initialise Markdown 53 | md = "" 54 | 55 | subdir = path.join(DIR, module) 56 | 57 | # Skip if not directory 58 | if not path.isdir(subdir) or "__" in module: 59 | continue 60 | 61 | md += process_dir(1, subdir) 62 | with open(f"./source/api/{module}.md", "w") as f: 63 | f.write(md) 64 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020-2021 CNRS 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE.
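# NOTE: version.py is generated at build time (it is listed under "Version Output" in this repository's .gitignore), so the import below is guarded by try/except to tolerate source checkouts where the file has not been generated yet.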
22 | 23 | try: 24 | from .version import __version__, git_version # noqa: F401 25 | except ImportError: 26 | pass 27 | 28 | 29 | from .core.inference import Inference 30 | from .core.io import Audio 31 | from .core.model import Model 32 | from .core.pipeline import Pipeline 33 | 34 | __all__ = ["Audio", "Model", "Inference", "Pipeline"] 35 | -------------------------------------------------------------------------------- /recipes/diar_ssl/data/AMI_AliMeeting_AISHELL4/test/AISHELL4/wav.scp: -------------------------------------------------------------------------------- 1 | S_R003S02C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/S_R003S02C01.wav 2 | L_R004S02C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/L_R004S02C01.wav 3 | S_R003S04C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/S_R003S04C01.wav 4 | L_R003S04C02 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/L_R003S04C02.wav 5 | S_R004S01C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/S_R004S01C01.wav 6 | L_R003S02C02 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/L_R003S02C02.wav 7 | S_R003S03C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/S_R003S03C01.wav 8 | S_R004S02C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/S_R004S02C01.wav 9 | L_R004S03C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/L_R004S03C01.wav 10 | L_R004S06C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/L_R004S06C01.wav 11 | L_R004S01C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/L_R004S01C01.wav 12 | M_R003S02C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/M_R003S02C01.wav 13 | L_R003S03C02 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/L_R003S03C02.wav 14 | S_R004S03C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/S_R004S03C01.wav 15 | L_R003S01C02 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/L_R003S01C02.wav 16 | S_R004S04C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/S_R004S04C01.wav 17 | M_R003S04C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/M_R003S04C01.wav 18 | M_R003S05C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/M_R003S05C01.wav 19 | M_R003S01C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/M_R003S01C01.wav 20 | S_R003S01C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/S_R003S01C01.wav 21 | -------------------------------------------------------------------------------- /MODEL_LICENSE: -------------------------------------------------------------------------------- 1 | MODEL LICENSE 2 | ============= 3 | 4 | The pre-trained model weights released in this repository ("the Models") 5 | are licensed under the Creative Commons Attribution-NonCommercial 4.0 6 | International License (CC BY-NC 4.0). 7 | 8 | Full legal text of the license is available at: 9 | https://creativecommons.org/licenses/by-nc/4.0/legalcode 10 | 11 | Summary of terms: 12 | - Attribution (BY): You must give appropriate credit, provide a link to the license, 13 | and indicate if changes were made. 14 | - NonCommercial (NC): You may not use the Models for commercial purposes. 15 | 16 | You are free to: 17 | - Use the Models for academic research and other non-commercial purposes. 18 | - Share and redistribute the Models, provided that you include proper attribution 19 | to the authors of DiariZen and retain this license notice. 20 | - Modify or adapt the Models for non-commercial use. 21 | 22 | You must not: 23 | - Use the Models or any derivatives thereof for commercial purposes, 24 | including but not limited to products, services, or paid offerings. 25 | - Remove or alter this license notice in any copies of the Models. 
26 | 27 | Reason for this restriction: 28 | The training data used for the Models includes datasets licensed under terms 29 | that forbid commercial usage and/or derivative works 30 | (e.g., CC BY-NC, CC BY-NC-ND, or Research Only licenses). 31 | Therefore, the Models cannot be used commercially. 32 | 33 | Note: 34 | The source code in this repository is licensed separately under the MIT License 35 | (see LICENSE file). This MODEL_LICENSE applies only to the released pre-trained 36 | model weights. 37 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/models/embedding/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020-2021 CNRS 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | 24 | from .wespeaker import ( 25 | WeSpeakerResNet34, 26 | WeSpeakerResNet152, 27 | WeSpeakerResNet221, 28 | WeSpeakerResNet293, 29 | ) 30 | from .xvector import XVectorMFCC, XVectorSincNet 31 | 32 | __all__ = [ 33 | "XVectorSincNet", 34 | "XVectorMFCC", 35 | "WeSpeakerResNet34", 36 | "WeSpeakerResNet152", 37 | "WeSpeakerResNet221", 38 | "WeSpeakerResNet293", 39 | ] 40 | -------------------------------------------------------------------------------- /diarizen/optimization.py: -------------------------------------------------------------------------------- 1 | # Licensed under the MIT license. 
2 | # Copied from https://github.com/haoxiangsnr/spiking-fullsubnet/blob/main/audiozen/optimization.py 3 | # Copyright 2024 Hong Kong Polytechnic University (author: Xiang Hao, haoxiangsnr@gmail.com) 4 | 5 | from functools import partial 6 | 7 | from torch.optim import Optimizer 8 | from torch.optim.lr_scheduler import LambdaLR 9 | 10 | 11 | def _get_constant_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int): 12 | if current_step < num_warmup_steps: 13 | return float(current_step) / float(max(1.0, num_warmup_steps)) 14 | return 1.0 15 | 16 | 17 | def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1): 18 | lr_lambda = partial(_get_constant_schedule_with_warmup_lr_lambda, num_warmup_steps=num_warmup_steps) 19 | return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch) 20 | 21 | 22 | def _get_linear_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int): 23 | if current_step < num_warmup_steps: 24 | return float(current_step) / float(max(1, num_warmup_steps)) 25 | return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))) 26 | 27 | 28 | def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): 29 | lr_lambda = partial( 30 | _get_linear_schedule_with_warmup_lr_lambda, 31 | num_warmup_steps=num_warmup_steps, 32 | num_training_steps=num_training_steps, 33 | ) 34 | return LambdaLR(optimizer, lr_lambda, last_epoch) 35 | -------------------------------------------------------------------------------- /pyannote-audio/.faq/suggest.md: -------------------------------------------------------------------------------- 1 | Thank you for your issue. 2 | 3 | {%- if questions -%} 4 | {% if questions|length == 1 %} 5 | We found the following entry in the [FAQ]({{ faq_url }}) which you may find helpful: 6 | {%- else %} 7 | We found the following entries in the [FAQ]({{ faq_url }}) which you may find helpful: 8 | {%- endif %} 9 | 10 | {% for question in questions %} 11 | - [{{ question.title }}]({{ faq_url }}#{{ question.slug }}) 12 | {%- endfor %} 13 | 14 | {%- else -%} 15 | You might want to check the [FAQ]({{ faq_url }}) if you haven't done so already. 16 | {%- endif %} 17 | 18 | Feel free to close this issue if you found an answer in the FAQ. 19 | 20 | If your issue is a feature request, please read [this](https://xyproblem.info/) first and update your request accordingly, if needed. 21 | 22 | If your issue is a bug report, please provide a [minimal reproducible example](https://stackoverflow.com/help/minimal-reproducible-example) as a link to a self-contained [Google Colab](https://colab.research.google.com/) notebook containing everything needed to reproduce the bug: 23 | - installation 24 | - data preparation 25 | - model download 26 | - etc. 27 | 28 | Providing an MRE will increase your chance of getting an answer from the community (either maintainers or other power users). 29 | 30 | Companies relying on `pyannote.audio` in production may contact [me](https://herve.niderb.fr) via email regarding: 31 | * paid scientific consulting around speaker diarization and speech processing in general; 32 | * custom models and tailored features (via the local tech transfer office).
33 | 34 | > This is an automated reply, generated by [FAQtory](https://github.com/willmcgugan/faqtory) 35 | -------------------------------------------------------------------------------- /recipes/diar_ssl/data/AMI_AliMeeting_AISHELL4/dev/all.uem: -------------------------------------------------------------------------------- 1 | ES2011a 1 0.000 1113.845375 2 | ES2011b 1 0.000 1581.269375 3 | ES2011c 1 0.000 1616.064000 4 | ES2011d 1 0.000 1982.325375 5 | IB4001 1 0.000 1780.650688 6 | IB4002 1 0.000 1882.368000 7 | IB4003 1 0.000 2023.253375 8 | IB4004 1 0.000 2392.832000 9 | IB4010 1 0.000 2960.554688 10 | IB4011 1 0.000 2416.981375 11 | IS1008a 1 0.000 943.833313 12 | IS1008b 1 0.000 1768.500000 13 | IS1008c 1 0.000 1546.333312 14 | IS1008d 1 0.000 1480.833312 15 | TS3004a 1 0.000 1345.322625 16 | TS3004b 1 0.000 2246.058625 17 | TS3004c 1 0.000 2970.000000 18 | TS3004d 1 0.000 2750.800000 19 | R8001_M8004_MS801 1 0.000 1573.85 20 | R8003_M8001_MS801 1 0.000 2068 21 | R8007_M8010_MS803 1 0.000 1856.322 22 | R8007_M8011_MS806 1 0.000 1861.546 23 | R8008_M8013_MS807 1 0.000 2239.47 24 | R8009_M8018_MS809 1 0.000 1654.657 25 | R8009_M8019_MS810 1 0.000 1973.926 26 | R8009_M8020_MS810 1 0.000 1910.443 27 | 20200706_L_R001S01C01 1 0.000 2011.088 28 | 20200708_L_R002S07C01 1 0.000 1939.912 29 | 20200709_L_R002S08C01 1 0.000 2040.388 30 | 20200616_M_R001S01C01 1 0.000 1868.833 31 | 20200715_M_R002S06C01 1 0.000 1915.477 32 | 20200620_M_R002S08C01 1 0.000 2258.16 33 | 20200620_M_R002S07C01 1 0.000 2196.459 34 | 20200710_M_R002S03C01 1 0.000 1900.796 35 | 20200710_M_R002S07C01 1 0.000 1794.738 36 | 20200712_M_R002S05C01 1 0.000 1875.699 37 | 20200704_M_R002S06C01 1 0.000 2154.912 38 | 20200622_M_R002S02C01 1 0.000 2236.539 39 | 20200715_M_R002S03C01 1 0.000 1929.546 40 | 20200623_S_R001S06C01 1 0.000 2087.362 41 | 20200805_S_R001S01C01 1 0.000 2195.362 42 | 20200701_S_R001S03C01 1 0.000 2273.842 43 | 20200805_S_R001S03C01 1 0.000 2224.031 44 | 20200702_S_R001S02C01 1 0.000 2209.856 45 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020-2022 CNRS 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
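# NOTE: the imports below re-export the built-in pipelines (multi-label segmentation, overlapped speech detection, resegmentation, speaker diarization, and voice activity detection) at package level; see __all__ for the public names.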
22 | 23 | from .multilabel import MultiLabelSegmentation 24 | from .overlapped_speech_detection import OverlappedSpeechDetection 25 | from .resegmentation import Resegmentation 26 | from .speaker_diarization import SpeakerDiarization 27 | from .voice_activity_detection import VoiceActivityDetection 28 | 29 | __all__ = [ 30 | "VoiceActivityDetection", 31 | "OverlappedSpeechDetection", 32 | "SpeakerDiarization", 33 | "Resegmentation", 34 | "MultiLabelSegmentation", 35 | ] 36 | -------------------------------------------------------------------------------- /recipes/diar_ssl/conf/pyannote_baseline.toml: -------------------------------------------------------------------------------- 1 | [meta] 2 | save_dir = "exp" 3 | seed = 3407 4 | 5 | [finetune] 6 | finetune = false 7 | 8 | [trainer] 9 | path = "trainer_single_opt.Trainer" 10 | [trainer.args] 11 | max_epochs = 100 12 | gradient_percentile = 90 13 | gradient_history_size = 1000 14 | save_max_score = false 15 | save_ckpt_interval = 1 16 | max_patience = 10 17 | max_num_checkpoints = 100 18 | gradient_accumulation_steps = 1 19 | validation_interval = 1 20 | freeze_wavlm = false 21 | lr_decay = false 22 | use_one_cycle_lr = false 23 | 24 | [optimizer] 25 | path = "torch.optim.AdamW" 26 | [optimizer.args] 27 | lr = 1e-3 28 | 29 | [model] 30 | path = "diarizen.models.eend.model_pyannote.Model" 31 | [model.args] 32 | chunk_size = 8 33 | selected_channel = 0 34 | max_speakers_per_chunk = 4 35 | 36 | [train_dataset] 37 | path = "dataset.DiarizationDataset" 38 | [train_dataset.args] 39 | scp_file = "data/AMI_AliMeeting_AISHELL4/train/wav.scp" 40 | rttm_file = "data/AMI_AliMeeting_AISHELL4/train/rttm" 41 | uem_file = "data/AMI_AliMeeting_AISHELL4/train/all.uem" 42 | chunk_size = 8 43 | chunk_shift = 6 44 | sample_rate = 16000 45 | 46 | [train_dataset.dataloader] 47 | batch_size = 32 48 | num_workers = 1 49 | drop_last = true 50 | pin_memory = true 51 | 52 | [validate_dataset] 53 | path = "dataset.DiarizationDataset" 54 | [validate_dataset.args] 55 | scp_file = "data/AMI_AliMeeting_AISHELL4/dev/wav.scp" 56 | rttm_file = "data/AMI_AliMeeting_AISHELL4/dev/rttm" 57 | uem_file = "data/AMI_AliMeeting_AISHELL4/dev/all.uem" 58 | chunk_size = 8 59 | chunk_shift = 8 60 | sample_rate = 16000 61 | 62 | [validate_dataset.dataloader] 63 | batch_size = 16 64 | num_workers = 1 65 | drop_last = true 66 | pin_memory = true 67 | 68 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/pipelines/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2022- CNRS 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | from .diarization import SpeakerDiarizationMixin 24 | from .getter import ( 25 | PipelineAugmentation, 26 | PipelineInference, 27 | PipelineModel, 28 | get_augmentation, 29 | get_devices, 30 | get_inference, 31 | get_model, 32 | ) 33 | from .oracle import oracle_segmentation 34 | 35 | __all__ = [ 36 | "SpeakerDiarizationMixin", 37 | "oracle_segmentation", 38 | "get_augmentation", 39 | "PipelineAugmentation", 40 | "get_devices", 41 | "get_inference", 42 | "PipelineInference", 43 | "get_model", 44 | "PipelineModel", 45 | ] 46 | -------------------------------------------------------------------------------- /recipes/diar_ssl/data/AMI_AliMeeting_AISHELL4/test/AliMeeting/wav.scp: -------------------------------------------------------------------------------- 1 | R8002_M8002_MS802 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8002_M8002_MS802.wav 2 | R8002_M8003_MS803 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8002_M8003_MS803.wav 3 | R8004_M8005_MS803 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8004_M8005_MS803.wav 4 | R8004_M8006_MS805 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8004_M8006_MS805.wav 5 | R8005_M8007_MS806 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8005_M8007_MS806.wav 6 | R8005_M8008_MS806 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8005_M8008_MS806.wav 7 | R8005_M8009_MS802 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8005_M8009_MS802.wav 8 | R8006_M8012_MS803 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8006_M8012_MS803.wav 9 | R8008_M8014_MS807 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8008_M8014_MS807.wav 10 | R8008_M8015_MS808 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8008_M8015_MS808.wav 11 | R8008_M8016_MS808 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8008_M8016_MS808.wav 12 | R8008_M8017_MS808 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8008_M8017_MS808.wav 13 | R8009_M8021_MS810 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8009_M8021_MS810.wav 14 | R8009_M8022_MS810 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8009_M8022_MS810.wav 15 | R8009_M8023_MS811 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8009_M8023_MS811.wav 16 | R8009_M8024_MS811 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8009_M8024_MS811.wav 17 | R8009_M8025_MS811 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8009_M8025_MS811.wav 18 | R8009_M8026_MS812 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8009_M8026_MS812.wav 19 | R8009_M8027_MS812 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8009_M8027_MS812.wav 20 | R8009_M8028_MS812 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/test/R8009_M8028_MS812.wav 21 | -------------------------------------------------------------------------------- /pyannote-audio/tests/utils/probe_util_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from pyannote.audio.utils.probe import probe 5 | 6 | 7 | class Trunk(nn.Module): 8 | def __init__(self): 9 | super().__init__() 10 | self.layer1 = nn.Linear(1, 2) 11 | self.layer2 = nn.Linear(2, 3) 12 | self.layer3 = nn.Linear(3, 4) 13 | 14 | def forward(self, x): 15 | return self.layer3(self.layer2(self.layer1(x))) 16 | 17 | 18 | def test_probe_dict(): 19 | trunk = Trunk() 20 | probe(trunk, {"probe1": "layer1"}) 21 | out = trunk( 
22 | torch.ones( 23 | 1, 24 | ) 25 | ) 26 | assert isinstance(out, dict) 27 | assert len(out.keys()) == 1 28 | assert isinstance(out["probe1"], torch.Tensor) 29 | 30 | 31 | def test_probe_output(): 32 | trunk = Trunk() 33 | probe(trunk, {"probe1": "layer3"}) 34 | out = trunk( 35 | torch.ones( 36 | 1, 37 | ) 38 | ) 39 | out = out["probe1"] 40 | tout = trunk.layer3( 41 | trunk.layer2( 42 | trunk.layer1( 43 | torch.ones( 44 | 1, 45 | ) 46 | ) 47 | ) 48 | ) 49 | assert torch.equal(tout, out) 50 | 51 | 52 | def test_probe_revert(): 53 | trunk = Trunk() 54 | revert = probe(trunk, {"probe1": "layer3"}) 55 | out = trunk( 56 | torch.ones( 57 | 1, 58 | ) 59 | ) 60 | assert isinstance(out, dict) 61 | revert() 62 | out = trunk( 63 | torch.ones( 64 | 1, 65 | ) 66 | ) 67 | assert isinstance(out, torch.Tensor) 68 | 69 | 70 | def test_probe_array(): 71 | trunk = Trunk() 72 | probe(trunk, ["layer3"]) 73 | out = trunk( 74 | torch.ones( 75 | 1, 76 | ) 77 | ) 78 | assert isinstance(out, dict) 79 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/torchmetrics/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2022- CNRS 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
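# NOTE: re-exports the diarization error rate family of metrics (DER, its false alarm / missed detection / speaker confusion components, and their optimal-threshold variants) at package level.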
22 | 23 | 24 | from .audio.diarization_error_rate import ( 25 | DiarizationErrorRate, 26 | FalseAlarmRate, 27 | MissedDetectionRate, 28 | OptimalDiarizationErrorRate, 29 | OptimalDiarizationErrorRateThreshold, 30 | OptimalFalseAlarmRate, 31 | OptimalMissedDetectionRate, 32 | OptimalSpeakerConfusionRate, 33 | SpeakerConfusionRate, 34 | ) 35 | 36 | __all__ = [ 37 | "DiarizationErrorRate", 38 | "FalseAlarmRate", 39 | "MissedDetectionRate", 40 | "SpeakerConfusionRate", 41 | "OptimalDiarizationErrorRate", 42 | "OptimalFalseAlarmRate", 43 | "OptimalMissedDetectionRate", 44 | "OptimalSpeakerConfusionRate", 45 | "OptimalDiarizationErrorRateThreshold", 46 | ] 47 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/torchmetrics/audio/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2022- CNRS 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
22 | 23 | 24 | from .diarization_error_rate import ( 25 | DiarizationErrorRate, 26 | FalseAlarmRate, 27 | MissedDetectionRate, 28 | OptimalDiarizationErrorRate, 29 | OptimalDiarizationErrorRateThreshold, 30 | OptimalFalseAlarmRate, 31 | OptimalMissedDetectionRate, 32 | OptimalSpeakerConfusionRate, 33 | SpeakerConfusionRate, 34 | ) 35 | 36 | __all__ = [ 37 | "DiarizationErrorRate", 38 | "SpeakerConfusionRate", 39 | "MissedDetectionRate", 40 | "FalseAlarmRate", 41 | "OptimalDiarizationErrorRate", 42 | "OptimalSpeakerConfusionRate", 43 | "OptimalMissedDetectionRate", 44 | "OptimalFalseAlarmRate", 45 | "OptimalDiarizationErrorRateThreshold", 46 | ] 47 | -------------------------------------------------------------------------------- /pyannote-audio/tutorials/speaker_verification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "#### Speaker verification\n", 9 | "\n", 10 | "```python\n", 11 | "import torch\n", 12 | "from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding\n", 13 | "model = PretrainedSpeakerEmbedding(\n", 14 | " \"speechbrain/spkrec-ecapa-voxceleb\",\n", 15 | " device=torch.device(\"cuda\"))\n", 16 | "\n", 17 | "from pyannote.audio import Audio\n", 18 | "from pyannote.core import Segment\n", 19 | "audio = Audio(sample_rate=16000, mono=\"downmix\")\n", 20 | "\n", 21 | "# extract embedding for a speaker speaking between t=3s and t=6s\n", 22 | "speaker1 = Segment(3., 6.)\n", 23 | "waveform1, sample_rate = audio.crop(\"audio.wav\", speaker1)\n", 24 | "embedding1 = model(waveform1[None])\n", 25 | "\n", 26 | "# extract embedding for a speaker speaking between t=7s and t=12s\n", 27 | "speaker2 = Segment(7., 12.)\n", 28 | "waveform2, sample_rate = audio.crop(\"audio.wav\", speaker2)\n", 29 | "embedding2 = model(waveform2[None])\n", 30 | "\n", 31 | "# compare embeddings using \"cosine\" distance\n", 32 | "from scipy.spatial.distance import cdist\n", 33 | "distance = cdist(embedding1, embedding2, metric=\"cosine\")\n", 34 | "```\n" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": {}, 40 | "source": [] 41 | } 42 | ], 43 | "metadata": { 44 | "interpreter": { 45 | "hash": "41379f2c2a4eb17f5ac9a1f5014f4b793a0ead0b6469d8877f81a91eb030f53e" 46 | }, 47 | "kernelspec": { 48 | "display_name": "Python 3.8.2 64-bit ('pyannote': conda)", 49 | "language": "python", 50 | "name": "python3" 51 | }, 52 | "language_info": { 53 | "name": "python", 54 | "version": "3.8.2" 55 | } 56 | }, 57 | "nbformat": 4, 58 | "nbformat_minor": 2 59 | } 60 | -------------------------------------------------------------------------------- /recipes/diar_ssl/conf/wavlm_frozen_conformer.toml: -------------------------------------------------------------------------------- 1 | [meta] 2 | save_dir = "exp" 3 | seed = 3407 4 | 5 | [finetune] 6 | finetune = false 7 | 8 | [trainer] 9 | path = "trainer_single_opt.Trainer" 10 | [trainer.args] 11 | max_epochs = 100 12 | gradient_percentile = 90 13 | gradient_history_size = 1000 14 | save_max_score = false 15 | save_ckpt_interval = 1 16 | max_patience = 10 17 | max_num_checkpoints = 100 18 | gradient_accumulation_steps = 1 19 | validation_interval = 1 20 | freeze_wavlm = true 21 | lr_decay = false 22 | use_one_cycle_lr = false 23 | 24 | [optimizer] 25 | path = "torch.optim.AdamW" 26 | [optimizer.args] 27 | lr = 1e-3 28 | 29 | [model] 30 | path = 
"diarizen.models.eend.model_wavlm_conformer.Model" 31 | [model.args] 32 | wavlm_src = "/YOUR_PATH/WavLM-Base+.pt" 33 | wavlm_layer_num = 13 34 | wavlm_feat_dim = 768 35 | attention_in = 256 36 | ffn_hidden = 1024 37 | num_head = 4 38 | num_layer = 4 39 | dropout = 0.1 40 | chunk_size = 8 41 | use_posi = false 42 | output_activate_function = false 43 | selected_channel = 0 44 | max_speakers_per_chunk = 4 45 | 46 | [train_dataset] 47 | path = "dataset.DiarizationDataset" 48 | [train_dataset.args] 49 | scp_file = "data/AMI_AliMeeting_AISHELL4/train/wav.scp" 50 | rttm_file = "data/AMI_AliMeeting_AISHELL4/train/rttm" 51 | uem_file = "data/AMI_AliMeeting_AISHELL4/train/all.uem" 52 | chunk_size = 8 53 | chunk_shift = 6 54 | sample_rate = 16000 55 | 56 | [train_dataset.dataloader] 57 | batch_size = 32 58 | num_workers = 1 59 | drop_last = true 60 | pin_memory = true 61 | 62 | [validate_dataset] 63 | path = "dataset.DiarizationDataset" 64 | [validate_dataset.args] 65 | scp_file = "data/AMI_AliMeeting_AISHELL4/dev/wav.scp" 66 | rttm_file = "data/AMI_AliMeeting_AISHELL4/dev/rttm" 67 | uem_file = "data/AMI_AliMeeting_AISHELL4/dev/all.uem" 68 | chunk_size = 8 69 | chunk_shift = 8 70 | sample_rate = 16000 71 | 72 | [validate_dataset.dataloader] 73 | batch_size = 8 74 | num_workers = 1 75 | drop_last = true 76 | pin_memory = true 77 | 78 | -------------------------------------------------------------------------------- /recipes/diar_ssl/conf/fbank_conformer.toml: -------------------------------------------------------------------------------- 1 | [meta] 2 | save_dir = "exp" 3 | seed = 3407 4 | 5 | [finetune] 6 | finetune = false 7 | 8 | [trainer] 9 | path = "trainer_single_opt.Trainer" 10 | [trainer.args] 11 | max_epochs = 100 12 | gradient_percentile = 90 13 | gradient_history_size = 1000 14 | save_max_score = false 15 | save_ckpt_interval = 1 16 | max_patience = 10 17 | max_num_checkpoints = 100 18 | gradient_accumulation_steps = 1 19 | validation_interval = 1 20 | freeze_wavlm = false 21 | lr_decay = false 22 | use_one_cycle_lr = false 23 | 24 | [optimizer] 25 | path = "torch.optim.AdamW" 26 | [optimizer.args] 27 | lr = 1e-3 28 | 29 | [model] 30 | path = "diarizen.models.eend.model_fbank_conformer.Model" 31 | [model.args] 32 | n_fft = 400 33 | n_mels = 80 34 | win_length = 25 # ms 35 | hop_length = 10 # ms 36 | sample_rate = 16000 37 | attention_in = 256 38 | ffn_hidden = 1024 39 | num_head = 4 40 | num_layer = 4 41 | dropout = 0.1 42 | chunk_size = 8 43 | use_posi = false 44 | output_activate_function = false 45 | selected_channel = 0 46 | max_speakers_per_chunk = 4 47 | 48 | [train_dataset] 49 | path = "dataset.DiarizationDataset" 50 | [train_dataset.args] 51 | scp_file = "data/AMI_AliMeeting_AISHELL4/train/wav.scp" 52 | rttm_file = "data/AMI_AliMeeting_AISHELL4/train/rttm" 53 | uem_file = "data/AMI_AliMeeting_AISHELL4/train/all.uem" 54 | chunk_size = 8 55 | chunk_shift = 6 56 | sample_rate = 16000 57 | 58 | [train_dataset.dataloader] 59 | batch_size = 16 60 | num_workers = 1 61 | drop_last = true 62 | pin_memory = true 63 | 64 | [validate_dataset] 65 | path = "dataset.DiarizationDataset" 66 | [validate_dataset.args] 67 | scp_file = "data/AMI_AliMeeting_AISHELL4/dev/wav.scp" 68 | rttm_file = "data/AMI_AliMeeting_AISHELL4/dev/rttm" 69 | uem_file = "data/AMI_AliMeeting_AISHELL4/dev/all.uem" 70 | chunk_size = 8 71 | chunk_shift = 8 72 | sample_rate = 16000 73 | 74 | [validate_dataset.dataloader] 75 | batch_size = 16 76 | num_workers = 1 77 | drop_last = true 78 | pin_memory = true 79 | 80 | 
-------------------------------------------------------------------------------- /pyannote-audio/tests/tasks/test_reproducibility.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from lightning.pytorch import seed_everything 3 | from pyannote.database import FileFinder, get_protocol 4 | 5 | from pyannote.audio.models.segmentation.debug import SimpleSegmentationModel 6 | from pyannote.audio.tasks import VoiceActivityDetection 7 | 8 | 9 | def setup_tasks(task): 10 | protocol = get_protocol( 11 | "Debug.SpeakerDiarization.Debug", preprocessors={"audio": FileFinder()} 12 | ) 13 | vad = task(protocol, duration=0.2, batch_size=32, num_workers=4) 14 | return protocol, vad 15 | 16 | 17 | def create_dl(model, task): 18 | m = model(task=task) 19 | m.prepare_data() 20 | m.setup() 21 | return task.train_dataloader() 22 | 23 | 24 | def get_next5(dl): 25 | last5 = [] 26 | it = iter(dl) 27 | for i in range(5): 28 | last5.append(next(it)) 29 | return last5 30 | 31 | 32 | def test_seeding_ensures_data_loaders(): 33 | "Setting a global seed for the dataloaders ensures that we get data back in the same order" 34 | 35 | seed_everything(1) 36 | protocol, vad = setup_tasks(VoiceActivityDetection) 37 | dl = create_dl(SimpleSegmentationModel, vad) 38 | last5a = get_next5(dl) 39 | 40 | seed_everything(1) 41 | protocol, vad = setup_tasks(VoiceActivityDetection) 42 | dl = create_dl(SimpleSegmentationModel, vad) 43 | last5b = get_next5(dl) 44 | 45 | for i in range(len(last5b)): 46 | assert torch.equal(last5a[i]["X"], last5b[i]["X"]) 47 | 48 | 49 | def test_different_seeds(): 50 | "Changing the global seed will change the order of the data that loads" 51 | 52 | protocol, vad = setup_tasks(VoiceActivityDetection) 53 | seed_everything(4) 54 | dl = create_dl(SimpleSegmentationModel, vad) 55 | last5a = get_next5(dl) 56 | 57 | protocol, vad = setup_tasks(VoiceActivityDetection) 58 | seed_everything(5) 59 | dl = create_dl(SimpleSegmentationModel, vad) 60 | last5b = get_next5(dl) 61 | 62 | for i in range(5): 63 | assert not torch.equal(last5a[i]["X"], last5b[i]["X"]) 64 | -------------------------------------------------------------------------------- /recipes/diar_ssl/conf/wavlm_updated_conformer.toml: -------------------------------------------------------------------------------- 1 | [meta] 2 | save_dir = "exp" 3 | seed = 3407 4 | 5 | [finetune] 6 | finetune = false 7 | 8 | [trainer] 9 | path = "trainer_dual_opt.Trainer" 10 | [trainer.args] 11 | max_epochs = 100 12 | gradient_percentile = 90 13 | gradient_history_size = 1000 14 | save_max_score = false 15 | save_ckpt_interval = 1 16 | max_patience = 10 17 | max_num_checkpoints = 100 18 | gradient_accumulation_steps = 1 19 | validation_interval = 1 20 | freeze_wavlm = false 21 | lr_decay = false 22 | use_one_cycle_lr = false 23 | 24 | [optimizer_small] 25 | path = "torch.optim.AdamW" 26 | [optimizer_small.args] 27 | lr = 2e-5 28 | 29 | [optimizer_big] 30 | path = "torch.optim.AdamW" 31 | [optimizer_big.args] 32 | lr = 1e-3 33 | 34 | [model] 35 | path = "diarizen.models.eend.model_wavlm_conformer.Model" 36 | [model.args] 37 | wavlm_src = "/YOUR_PATH/WavLM-Base+.pt" 38 | wavlm_layer_num = 13 39 | wavlm_feat_dim = 768 40 | attention_in = 256 41 | ffn_hidden = 1024 42 | num_head = 4 43 | num_layer = 4 44 | dropout = 0.1 45 | chunk_size = 8 46 | use_posi = false 47 | output_activate_function = false 48 | selected_channel = 0 49 | max_speakers_per_chunk = 4 50 | 51 | [train_dataset] 52 | path = "dataset.DiarizationDataset" 
53 | [train_dataset.args] 54 | scp_file = "data/AMI_AliMeeting_AISHELL4/train/wav.scp" 55 | rttm_file = "data/AMI_AliMeeting_AISHELL4/train/rttm" 56 | uem_file = "data/AMI_AliMeeting_AISHELL4/train/all.uem" 57 | chunk_size = 8 58 | chunk_shift = 6 59 | sample_rate = 16000 60 | 61 | [train_dataset.dataloader] 62 | batch_size = 16 63 | num_workers = 1 64 | drop_last = true 65 | pin_memory = true 66 | 67 | [validate_dataset] 68 | path = "dataset.DiarizationDataset" 69 | [validate_dataset.args] 70 | scp_file = "data/AMI_AliMeeting_AISHELL4/dev/wav.scp" 71 | rttm_file = "data/AMI_AliMeeting_AISHELL4/dev/rttm" 72 | uem_file = "data/AMI_AliMeeting_AISHELL4/dev/all.uem" 73 | chunk_size = 8 74 | chunk_shift = 8 75 | sample_rate = 16000 76 | 77 | [validate_dataset.dataloader] 78 | batch_size = 8 79 | num_workers = 1 80 | drop_last = true 81 | pin_memory = true 82 | 83 | -------------------------------------------------------------------------------- /diarizen/models/module/wav2vec2/pruning_utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions for pruning.""" 2 | 3 | from typing import Union 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: str): 10 | "Prune linear layer in place." 11 | # NOTE: weight: (out_features, in_features), bias: (out_features,) 12 | if dim == "input": 13 | dim = 1 14 | layer.in_features = len(index) 15 | elif dim == "output": 16 | dim = 0 17 | layer.out_features = len(index) 18 | else: 19 | raise ValueError 20 | 21 | layer.weight = nn.Parameter(layer.weight.index_select(dim, index).clone().detach()) 22 | if layer.bias is not None and dim == 0: 23 | layer.bias = nn.Parameter(layer.bias.index_select(0, index).clone().detach()) 24 | 25 | 26 | def prune_conv1d_layer(layer: nn.Conv1d, index: torch.LongTensor, dim: str): 27 | """Prune conv1d in place.""" 28 | # NOTE: weight: (out_channels, in_channels, kernel_size), bias: (out_channels,) 29 | if dim == "input": 30 | dim = 1 31 | layer.in_channels = len(index) 32 | elif dim == "output": 33 | dim = 0 34 | layer.out_channels = len(index) 35 | else: 36 | raise ValueError 37 | 38 | layer.weight = nn.Parameter(layer.weight.index_select(dim, index).clone().detach()) 39 | if layer.bias is not None and dim == 0: 40 | layer.bias = nn.Parameter(layer.bias.index_select(0, index).clone().detach()) 41 | 42 | 43 | def prune_layer_norm(layernorm: Union[nn.LayerNorm, nn.GroupNorm], index: torch.LongTensor): 44 | """Prune layer norm or group norm in place.""" 45 | layernorm.weight = nn.Parameter(layernorm.weight.index_select(0, index).clone().detach()) 46 | layernorm.bias = nn.Parameter(layernorm.bias.index_select(0, index).clone().detach()) 47 | if isinstance(layernorm, nn.LayerNorm): 48 | layernorm.normalized_shape = (len(index),) 49 | elif isinstance(layernorm, nn.GroupNorm): 50 | layernorm.num_groups = len(index) 51 | layernorm.num_channels = len(index) 52 | -------------------------------------------------------------------------------- /recipes/diar_ssl_pruning/conf/dual_opt_common.toml: -------------------------------------------------------------------------------- 1 | [meta] 2 | save_dir = "exp" 3 | seed = 3407 4 | 5 | [finetune] 6 | finetune = false 7 | 8 | [trainer] 9 | path = "trainer_dual_opt.Trainer" 10 | [trainer.args] 11 | debug = false 12 | max_epochs = 20 13 | gradient_percentile = 90 14 | gradient_history_size = 1000 15 | save_max_score = false 16 | save_ckpt_interval = 1 17 | max_patience = 
5 18 | max_num_checkpoints = 100 19 | gradient_accumulation_steps = 1 20 | validation_interval = 1 21 | freeze_wavlm = false 22 | lr_decay = false 23 | use_one_cycle_lr = false 24 | 25 | [optimizer_small] 26 | path = "torch.optim.AdamW" 27 | [optimizer_small.args] 28 | lr = 2e-5 29 | 30 | [optimizer_big] 31 | path = "torch.optim.AdamW" 32 | [optimizer_big.args] 33 | lr = 1e-3 34 | 35 | [model] 36 | path = "diarizen.models.eend.model_wavlm_conformer.Model" 37 | [model.args] 38 | wavlm_src = "/YOUR_PATH/wavlm-base-plus-finetuned.bin" 39 | wavlm_layer_num = 13 40 | wavlm_feat_dim = 768 41 | attention_in = 256 42 | ffn_hidden = 1024 43 | num_head = 4 44 | num_layer = 4 45 | dropout = 0.1 46 | chunk_size = 8 47 | use_posi = false 48 | output_activate_function = false 49 | selected_channel = 0 50 | max_speakers_per_chunk = 4 51 | 52 | [train_dataset] 53 | path = "dataset.DiarizationDataset" 54 | [train_dataset.args] 55 | scp_file = "data/AMI_AliMeeting_AISHELL4/train/wav.scp" 56 | rttm_file = "data/AMI_AliMeeting_AISHELL4/train/rttm" 57 | uem_file = "data/AMI_AliMeeting_AISHELL4/train/all.uem" 58 | chunk_size = 8 59 | chunk_shift = 6 60 | sample_rate = 16000 61 | 62 | [train_dataset.dataloader] 63 | batch_size = 16 64 | num_workers = 1 65 | drop_last = true 66 | pin_memory = true 67 | 68 | [validate_dataset] 69 | path = "dataset.DiarizationDataset" 70 | [validate_dataset.args] 71 | scp_file = "data/AMI_AliMeeting_AISHELL4/train/wav.scp" 72 | rttm_file = "data/AMI_AliMeeting_AISHELL4/train/rttm" 73 | uem_file = "data/AMI_AliMeeting_AISHELL4/train/all.uem" 74 | chunk_size = 8 75 | chunk_shift = 8 76 | sample_rate = 16000 77 | 78 | [validate_dataset.dataloader] 79 | batch_size = 16 80 | num_workers = 1 81 | drop_last = true 82 | pin_memory = true 83 | 84 | -------------------------------------------------------------------------------- /recipes/diar_ssl_pruning/conf/dual_opt_common_large.toml: -------------------------------------------------------------------------------- 1 | [meta] 2 | save_dir = "exp" 3 | seed = 3407 4 | 5 | [finetune] 6 | finetune = false 7 | 8 | [trainer] 9 | path = "trainer_dual_opt.Trainer" 10 | [trainer.args] 11 | debug = false 12 | max_epochs = 20 13 | gradient_percentile = 90 14 | gradient_history_size = 1000 15 | save_max_score = false 16 | save_ckpt_interval = 1 17 | max_patience = 5 18 | max_num_checkpoints = 100 19 | gradient_accumulation_steps = 1 20 | validation_interval = 1 21 | freeze_wavlm = false 22 | lr_decay = false 23 | use_one_cycle_lr = false 24 | 25 | [optimizer_small] 26 | path = "torch.optim.AdamW" 27 | [optimizer_small.args] 28 | lr = 2e-5 29 | 30 | [optimizer_big] 31 | path = "torch.optim.AdamW" 32 | [optimizer_big.args] 33 | lr = 1e-3 34 | 35 | [model] 36 | path = "diarizen.models.eend.model_wavlm_conformer.Model" 37 | [model.args] 38 | wavlm_src = "/YOUR_PATH/wavlm-large-finetuned.bin" 39 | wavlm_layer_num = 25 40 | wavlm_feat_dim = 1024 41 | attention_in = 256 42 | ffn_hidden = 1024 43 | num_head = 4 44 | num_layer = 4 45 | dropout = 0.1 46 | chunk_size = 8 47 | use_posi = false 48 | output_activate_function = false 49 | selected_channel = 0 50 | max_speakers_per_chunk = 4 51 | 52 | [train_dataset] 53 | path = "dataset.DiarizationDataset" 54 | [train_dataset.args] 55 | scp_file = "data/AMI_AliMeeting_AISHELL4/train/wav.scp" 56 | rttm_file = "data/AMI_AliMeeting_AISHELL4/train/rttm" 57 | uem_file = "data/AMI_AliMeeting_AISHELL4/train/all.uem" 58 | chunk_size = 8 59 | chunk_shift = 6 60 | sample_rate = 16000 61 | 62 | [train_dataset.dataloader] 63 | 
batch_size = 16 64 | num_workers = 1 65 | drop_last = true 66 | pin_memory = true 67 | 68 | [validate_dataset] 69 | path = "dataset.DiarizationDataset" 70 | [validate_dataset.args] 71 | scp_file = "data/AMI_AliMeeting_AISHELL4/dev/wav.scp" 72 | rttm_file = "data/AMI_AliMeeting_AISHELL4/dev/rttm" 73 | uem_file = "data/AMI_AliMeeting_AISHELL4/dev/all.uem" 74 | chunk_size = 8 75 | chunk_shift = 8 76 | sample_rate = 16000 77 | 78 | [validate_dataset.dataloader] 79 | batch_size = 16 80 | num_workers = 1 81 | drop_last = true 82 | pin_memory = true 83 | 84 | -------------------------------------------------------------------------------- /diarizen/trainer_utils.py: -------------------------------------------------------------------------------- 1 | # Licensed under the MIT license. 2 | # Copied from https://github.com/haoxiangsnr/spiking-fullsubnet/blob/main/audiozen/trainer_utils.py 3 | # Copyright 2024 Hong Kong Polytechnic University (author: Xiang Hao, haoxiangsnr@gmail.com) 4 | 5 | import numpy as np 6 | import torch 7 | from accelerate.utils import set_seed 8 | 9 | 10 | def seed_worker(_): 11 | """Helper function to set the worker seed during DataLoader initialization. 12 | 13 | Recent PyTorch releases may have made this function unnecessary, because PyTorch already sets the worker seed 14 | for numpy and random. There is no adverse effect to keeping it, though, since the initial seed is 15 | the base seed plus the worker id. 16 | """ 17 | worker_seed = torch.initial_seed() % 2**32 18 | set_seed(worker_seed) 19 | 20 | 21 | def has_length(dataset): 22 | """ 23 | Check whether the dataset implements __len__() without raising an error. 24 | """ 25 | try: 26 | return len(dataset) is not None 27 | except TypeError: 28 | # TypeError: len() of unsized object 29 | return False 30 | 31 | 32 | class TrainerState: 33 | def __init__(self, save_max_score) -> None: 34 | self.epochs_trained = 0 35 | self.steps_trained = 0 36 | 37 | self.patience = 0 38 | 39 | self.best_score = -np.inf if save_max_score else np.inf 40 | self.best_score_epoch = 0 41 | 42 | def load_state_dict(self, state_dict: dict) -> None: 43 | self.epochs_trained = state_dict["epochs_trained"] 44 | self.steps_trained = state_dict["steps_trained"] 45 | 46 | self.best_score = state_dict["best_score"] 47 | self.best_score_epoch = state_dict["best_score_epoch"] 48 | 49 | self.patience = state_dict["patience"] 50 | 51 | def state_dict(self) -> dict: 52 | return { 53 | "epochs_trained": self.epochs_trained, 54 | "steps_trained": self.steps_trained, 55 | "patience": self.patience, 56 | "best_score": self.best_score, 57 | "best_score_epoch": self.best_score_epoch, 58 | } 59 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020-2021 CNRS 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the
Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | from .segmentation.multilabel import MultiLabelSegmentation # isort:skip 24 | from .segmentation.speaker_diarization import SpeakerDiarization # isort:skip 25 | from .segmentation.voice_activity_detection import VoiceActivityDetection # isort:skip 26 | from .segmentation.overlapped_speech_detection import ( # isort:skip 27 | OverlappedSpeechDetection, 28 | ) 29 | from .embedding.arcface import SupervisedRepresentationLearningWithArcFace # isort:skip 30 | 31 | # Segmentation has been renamed to SpeakerDiarization but we keep Segmentation here for backward compatibility 32 | Segmentation = SpeakerDiarization 33 | 34 | # SpeakerEmbedding is more human-friendly 35 | SpeakerEmbedding = SupervisedRepresentationLearningWithArcFace 36 | 37 | __all__ = [ 38 | "SpeakerDiarization", 39 | "VoiceActivityDetection", 40 | "OverlappedSpeechDetection", 41 | "MultiLabelSegmentation", 42 | "SpeakerEmbedding", 43 | "Segmentation", 44 | ] 45 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/sample/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2024- CNRS 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
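As a quick orientation to the aliases defined in tasks/__init__.py above, here is a hedged sketch of task instantiation; the Debug protocol name mirrors the one used in the repository's own notebooks, and the keyword values are illustrative:

```python
from pyannote.database import FileFinder, get_protocol

from pyannote.audio.tasks import Segmentation, SpeakerDiarization

# Segmentation is kept as a backward-compatible alias of SpeakerDiarization.
assert Segmentation is SpeakerDiarization

protocol = get_protocol(
    "Debug.SpeakerDiarization.Debug", preprocessors={"audio": FileFinder()}
)
task = SpeakerDiarization(protocol, duration=5.0, batch_size=32)
```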
22 | 23 | 24 | from pathlib import Path 25 | 26 | from pyannote.core import Annotation, Segment, Timeline 27 | from pyannote.database.util import load_rttm 28 | 29 | from pyannote.audio.core.io import Audio, AudioFile 30 | 31 | 32 | def _sample() -> AudioFile: 33 | sample_wav = Path(__file__).parent / "sample.wav" 34 | uri = "sample" 35 | 36 | audio = Audio() 37 | waveform, sample_rate = audio(sample_wav) 38 | 39 | sample_rttm = Path(__file__).parent / "sample.rttm" 40 | 41 | annotation: Annotation = load_rttm(sample_rttm)[uri] 42 | duration = audio.get_duration(sample_wav) 43 | 44 | annotated: Timeline = Timeline([Segment(0.0, duration)], uri=uri) 45 | 46 | return { 47 | "audio": sample_wav, 48 | "uri": "sample", 49 | "waveform": waveform, 50 | "sample_rate": sample_rate, 51 | "annotation": annotation, 52 | "annotated": annotated, 53 | } 54 | 55 | 56 | SAMPLE_FILE = _sample() 57 | -------------------------------------------------------------------------------- /pyannote-audio/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from pathlib import Path 4 | 5 | from pkg_resources import VersionConflict, require 6 | from setuptools import find_packages, setup 7 | 8 | with open("README.md", mode="r", encoding="utf-8") as f: 9 | long_description = f.read() 10 | 11 | with open("requirements.txt", mode="r", encoding="utf-8") as f: 12 | requirements = f.read().splitlines() 13 | 14 | try: 15 | require("setuptools>=38.3") 16 | except VersionConflict: 17 | print("Error: version of setuptools is too old (<38.3)!") 18 | sys.exit(1) 19 | 20 | 21 | ROOT_DIR = Path(__file__).parent.resolve() 22 | # Creating the version file 23 | 24 | with open("version.txt", mode="r", encoding="utf-8") as f: 25 | version = f.read() 26 | 27 | version = version.strip() 28 | sha = "Unknown" 29 | 30 | if os.getenv("BUILD_VERSION"): 31 | version = os.getenv("BUILD_VERSION") 32 | elif sha != "Unknown": 33 | version += "+" + sha[:7] 34 | print("-- Building version " + version) 35 | 36 | version_path = ROOT_DIR / "pyannote" / "audio" / "version.py" 37 | 38 | with open(version_path, mode="w", encoding="utf-8") as f: 39 | f.write("__version__ = '{}'\n".format(version)) 40 | 41 | if __name__ == "__main__": 42 | setup( 43 | name="pyannote.audio", 44 | namespace_packages=["pyannote"], 45 | version=version, 46 | packages=find_packages(), 47 | install_requires=requirements, 48 | description="Neural building blocks for speaker diarization", 49 | long_description=long_description, 50 | long_description_content_type="text/markdown", 51 | author="Hervé Bredin", 52 | author_email="herve.bredin@irit.fr", 53 | url="https://github.com/pyannote/pyannote-audio", 54 | classifiers=[ 55 | "Development Status :: 4 - Beta", 56 | "Intended Audience :: Science/Research", 57 | "License :: OSI Approved :: MIT License", 58 | "Natural Language :: English", 59 | "Programming Language :: Python :: 3.8", 60 | "Programming Language :: Python :: 3.9", 61 | "Programming Language :: Python :: 3.10", 62 | "Topic :: Scientific/Engineering", 63 | ], 64 | ) 65 | -------------------------------------------------------------------------------- /diarizen/noam_updater.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Licensed under the MIT license. 
3 | # Copyright 2022 Brno University of Technology (author: Federico Landini, landini@fit.vut.cz) 4 | 5 | from typing import Any, Dict, Optional 6 | 7 | import torch.optim as optim 8 | 9 | 10 | class NoamOpt: 11 | "Optimizer wrapper implementing the Noam learning-rate schedule." 12 | def __init__(self, model_size: int, warmup: int, optimizer: optim.Optimizer) -> None: 13 | self.optimizer = optimizer 14 | self._step = 0 15 | self.warmup = warmup 16 | self.model_size = model_size 17 | self._rate = 0 18 | 19 | def state_dict(self) -> Dict[str, Any]: 20 | """Returns the state of the warmup scheduler as a :class:`dict`. 21 | It contains an entry for every variable in self.__dict__ which 22 | is not the optimizer. 23 | """ 24 | return { 25 | key: value 26 | for key, value in self.__dict__.items() if key != 'optimizer'} 27 | 28 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None: 29 | """Loads the warmup scheduler's state. 30 | Arguments: 31 | state_dict (dict): warmup scheduler state. 32 | Should be an object returned from a call to :meth:`state_dict`. 33 | """ 34 | self.__dict__.update(state_dict) 35 | 36 | def step(self) -> None: 37 | "Update parameters and rate" 38 | self._step += 1 39 | rate = self.rate() 40 | for p in self.optimizer.param_groups: 41 | p['lr'] = rate 42 | self._rate = rate 43 | self.optimizer.step() 44 | 45 | def rate(self, step: Optional[int] = None) -> float: 46 | "Compute the Noam learning rate for the given step." 47 | if step is None: 48 | step = self._step 49 | return ( 50 | self.model_size ** (-0.5) * 51 | min(step ** (-0.5), step * self.warmup ** (-1.5))) 52 | 53 | def get_rate(self) -> float: 54 | return self._rate 55 | 56 | def zero_grad(self) -> None: 57 | self.optimizer.zero_grad() 58 | 59 | 60 | def get_rate(optimizer: optim.Optimizer) -> float: 61 | if isinstance(optimizer, NoamOpt): 62 | return optimizer.get_rate() 63 | else: 64 | for param_group in optimizer.param_groups: 65 | return param_group['lr'] -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/utils/version.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020- CNRS 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE.
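A hedged usage sketch of the NoamOpt wrapper above, assuming the module is importable as diarizen.noam_updater; the model and loss are stand-ins:

```python
import torch
import torch.optim as optim

from diarizen.noam_updater import NoamOpt, get_rate

model = torch.nn.Linear(256, 4)  # stand-in model
base = optim.Adam(model.parameters(), lr=0.0)  # lr is overwritten by the wrapper
noam = NoamOpt(model_size=256, warmup=4000, optimizer=base)

for _ in range(5):
    loss = model(torch.randn(8, 256)).sum()  # stand-in loss
    loss.backward()
    noam.step()  # sets lr = 256**-0.5 * min(step**-0.5, step * 4000**-1.5)
    noam.zero_grad()

print(get_rate(noam))  # current learning rate after 5 warmup steps
```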
22 | 23 | from typing import Text 24 | 25 | from semver import VersionInfo 26 | 27 | 28 | def check_version(library: Text, theirs: Text, mine: Text, what: Text = "Pipeline"): 29 | 30 | theirs = ".".join(theirs.split(".")[:3]) 31 | mine = ".".join(mine.split(".")[:3]) 32 | 33 | theirs = VersionInfo.parse(theirs) 34 | mine = VersionInfo.parse(mine) 35 | 36 | if theirs.major > mine.major: 37 | print( 38 | f"{what} was trained with {library} {theirs}, yours is {mine}. " 39 | f"Bad things will probably happen unless you upgrade {library} to {theirs.major}.x." 40 | ) 41 | 42 | elif theirs.major < mine.major: 43 | print( 44 | f"{what} was trained with {library} {theirs}, yours is {mine}. " 45 | f"Bad things might happen unless you revert {library} to {theirs.major}.x." 46 | ) 47 | 48 | elif theirs.minor > mine.minor: 49 | print( 50 | f"{what} was trained with {library} {theirs}, yours is {mine}. " 51 | f"This should be OK but you might want to upgrade {library}." 52 | ) 53 | -------------------------------------------------------------------------------- /pyannote-audio/notebook/sharing.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from pyannote.database import get_protocol, FileFinder\n", 10 | "protocol = get_protocol('Debug.SpeakerDiarization.Debug', \n", 11 | " preprocessors={\"audio\": FileFinder()})" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "## Train a model" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "from pyannote.audio.tasks import VoiceActivityDetection\n", 28 | "from pyannote.audio.models.segmentation.debug import SimpleSegmentationModel\n", 29 | "import pytorch_lightning as pl\n", 30 | "\n", 31 | "vad = VoiceActivityDetection(protocol, duration=2., batch_size=32, num_workers=4)\n", 32 | "model = SimpleSegmentationModel(task=vad)\n", 33 | "trainer = pl.Trainer(max_epochs=1, default_root_dir='sharing/')\n", 34 | "_ = trainer.fit(model)" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": {}, 40 | "source": [ 41 | "## Load a model without knowing its class" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "from pyannote.audio import Model\n", 51 | "model = Model.from_pretrained('sharing/lightning_logs/version_0/checkpoints/epoch=0-step=3.ckpt')\n", 52 | "assert isinstance(model, SimpleSegmentationModel)\n", 53 | "\n", 54 | "# checkpoint should work with a URL as well (it relies on pl_load)" 55 | ] 56 | } 57 | ], 58 | "metadata": { 59 | "kernelspec": { 60 | "display_name": "Python 3", 61 | "language": "python", 62 | "name": "python3" 63 | }, 64 | "language_info": { 65 | "codemirror_mode": { 66 | "name": "ipython", 67 | "version": 3 68 | }, 69 | "file_extension": ".py", 70 | "mimetype": "text/x-python", 71 | "name": "python", 72 | "nbconvert_exporter": "python", 73 | "pygments_lexer": "ipython3", 74 | "version": "3.8.5" 75 | } 76 | }, 77 | "nbformat": 4, 78 | "nbformat_minor": 4 79 | } 80 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/utils/random.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020 CNRS 4 | # 
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | 24 | import os 25 | import zlib 26 | from random import Random 27 | 28 | import torch 29 | 30 | 31 | def create_rng_for_worker(model) -> Random: 32 | """Create worker-specific random number generator 33 | 34 | This makes sure that 35 | 1. training samples generation is reproducible 36 | 2. every (worker, epoch) uses a different seed 37 | 38 | Parameters 39 | ---------- 40 | model : Model 41 | Model whose local/global rank and current epoch are mixed into the seed. 42 | """ 43 | 44 | # create random number generator 45 | rng = Random() 46 | 47 | global_seed = os.environ.get("PL_GLOBAL_SEED", "unset") 48 | worker_info = torch.utils.data.get_worker_info() 49 | 50 | if worker_info is None: 51 | worker_id = None 52 | else: 53 | worker_id = worker_info.id 54 | 55 | seed_tuple = ( 56 | global_seed, 57 | worker_id, 58 | model.local_rank, 59 | model.global_rank, 60 | model.current_epoch, 61 | ) 62 | # use adler32 because python's `hash` is not deterministic.
63 | seed = zlib.adler32(str(seed_tuple).encode()) 64 | rng.seed(seed) 65 | 66 | return rng 67 | -------------------------------------------------------------------------------- /recipes/diar_ssl_pruning/conf/s80_base.toml: -------------------------------------------------------------------------------- 1 | [meta] 2 | save_dir = "exp" 3 | seed = 3407 4 | 5 | [finetune] 6 | finetune = false 7 | 8 | [trainer] 9 | path = "trainer_distill_prune.Trainer" 10 | [trainer.args] 11 | debug = false 12 | max_epochs = 30 13 | gradient_percentile = 90 14 | gradient_history_size = 1000 15 | save_max_score = false 16 | save_ckpt_interval = 1 17 | max_patience = 5 18 | max_num_checkpoints = 100 19 | gradient_accumulation_steps = 1 20 | validation_interval = 1 21 | freeze_wavlm = false 22 | lr_decay = false 23 | use_one_cycle_lr = false 24 | # prune 25 | use_reg = true 26 | target_sparsity = 0.8 27 | pre_train_epochs = 0 28 | sparsity_warmup_epochs = 5 29 | 30 | [optimizer] 31 | path = "torch.optim.AdamW" 32 | [optimizer.args] 33 | distill_lr = 2e-4 34 | reg_lr = 2e-2 35 | 36 | [model] 37 | path = "diarizen.models.pruning.model_distill_prune.Model" 38 | [model.args] 39 | teacher_ckpt = "/YOUR_PATH/wavlm-base-plus-finetuned.bin" 40 | student_ckpt = "/YOUR_PATH/wavlm-base-plus-finetuned.bin" 41 | pruning_units = "conv,head,interm" 42 | distill_layers = "0,4,8,12" 43 | 44 | [distill_loss] 45 | path = "diarizen.models.pruning.utils.DistillLoss" 46 | [distill_loss.args] 47 | l2_weight = 0 48 | l1_weight = 1 49 | cos_weight = 1 50 | cos_type = "raw" 51 | 52 | [train_dataset] 53 | path = "dataset.DiarizationDataset" 54 | [train_dataset.args] 55 | scp_file = "data/AMI_AliMeeting_AISHELL4/train/wav.scp" 56 | rttm_file = "data/AMI_AliMeeting_AISHELL4/train/rttm" 57 | uem_file = "data/AMI_AliMeeting_AISHELL4/train/all.uem" 58 | chunk_size = 8 59 | chunk_shift = 6 60 | sample_rate = 16000 61 | model_num_frames = 399 62 | model_rf_duration = 0.025 63 | model_rf_step = 0.02 64 | 65 | [train_dataset.dataloader] 66 | batch_size = 16 67 | num_workers = 1 68 | drop_last = true 69 | pin_memory = true 70 | 71 | [validate_dataset] 72 | path = "dataset.DiarizationDataset" 73 | [validate_dataset.args] 74 | scp_file = "data/AMI_AliMeeting_AISHELL4/dev/wav.scp" 75 | rttm_file = "data/AMI_AliMeeting_AISHELL4/dev/rttm" 76 | uem_file = "data/AMI_AliMeeting_AISHELL4/dev/all.uem" 77 | chunk_size = 8 78 | chunk_shift = 8 79 | sample_rate = 16000 80 | model_num_frames = 399 81 | model_rf_duration = 0.025 82 | model_rf_step = 0.02 83 | 84 | [validate_dataset.dataloader] 85 | batch_size = 16 86 | num_workers = 1 87 | drop_last = true 88 | pin_memory = true 89 | 90 | -------------------------------------------------------------------------------- /recipes/diar_ssl_pruning/conf/s80_large.toml: -------------------------------------------------------------------------------- 1 | [meta] 2 | save_dir = "exp" 3 | seed = 3407 4 | 5 | [finetune] 6 | finetune = false 7 | 8 | [trainer] 9 | path = "trainer_distill_prune.Trainer" 10 | [trainer.args] 11 | debug = false 12 | max_epochs = 30 13 | gradient_percentile = 90 14 | gradient_history_size = 1000 15 | save_max_score = false 16 | save_ckpt_interval = 1 17 | max_patience = 5 18 | max_num_checkpoints = 100 19 | gradient_accumulation_steps = 1 20 | validation_interval = 1 21 | freeze_wavlm = false 22 | lr_decay = false 23 | use_one_cycle_lr = false 24 | # prune 25 | use_reg = true 26 | target_sparsity = 0.8 27 | pre_train_epochs = 0 28 | sparsity_warmup_epochs = 5 29 | 30 | [optimizer] 31 | path = 
"torch.optim.AdamW" 32 | [optimizer.args] 33 | distill_lr = 2e-4 34 | reg_lr = 2e-2 35 | 36 | [model] 37 | path = "diarizen.models.pruning.model_distill_prune.Model" 38 | [model.args] 39 | teacher_ckpt = "/YOUR_PATH/wavlm-large-finetuned.bin" 40 | student_ckpt = "/YOUR_PATH/wavlm-large-finetuned.bin" 41 | pruning_units = "conv,head,interm" 42 | distill_layers = "0,8,16,24" 43 | 44 | [distill_loss] 45 | path = "diarizen.models.pruning.utils.DistillLoss" 46 | [distill_loss.args] 47 | l2_weight = 0 48 | l1_weight = 1 49 | cos_weight = 1 50 | cos_type = "raw" 51 | 52 | [train_dataset] 53 | path = "dataset.DiarizationDataset" 54 | [train_dataset.args] 55 | scp_file = "data/AMI_AliMeeting_AISHELL4/train/wav.scp" 56 | rttm_file = "data/AMI_AliMeeting_AISHELL4/train/rttm" 57 | uem_file = "data/AMI_AliMeeting_AISHELL4/train/all.uem" 58 | chunk_size = 8 59 | chunk_shift = 6 60 | sample_rate = 16000 61 | model_num_frames = 399 62 | model_rf_duration = 0.025 63 | model_rf_step = 0.02 64 | 65 | [train_dataset.dataloader] 66 | batch_size = 8 67 | num_workers = 1 68 | drop_last = true 69 | pin_memory = true 70 | 71 | [validate_dataset] 72 | path = "dataset.DiarizationDataset" 73 | [validate_dataset.args] 74 | scp_file = "data/AMI_AliMeeting_AISHELL4/dev/wav.scp" 75 | rttm_file = "data/AMI_AliMeeting_AISHELL4/dev/rttm" 76 | uem_file = "data/AMI_AliMeeting_AISHELL4/dev/all.uem" 77 | chunk_size = 8 78 | chunk_shift = 8 79 | sample_rate = 16000 80 | model_num_frames = 399 81 | model_rf_duration = 0.025 82 | model_rf_step = 0.02 83 | 84 | [validate_dataset.dataloader] 85 | batch_size = 16 86 | num_workers = 1 87 | drop_last = true 88 | pin_memory = true 89 | 90 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/torchmetrics/classification/equal_error_rate.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2023- CNRS 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
22 | 23 | 24 | from typing import Optional 25 | 26 | import torch 27 | from pyannote.metrics.binary_classification import det_curve 28 | from torchmetrics import Metric 29 | from torchmetrics.utilities.data import dim_zero_cat 30 | 31 | 32 | class EqualErrorRate(Metric): 33 | 34 | is_differentiable: Optional[bool] = False 35 | higher_is_better: Optional[bool] = False 36 | full_state_update: bool = True 37 | 38 | def __init__(self, distances: bool = True, compute_on_cpu: bool = True, **kwargs): 39 | super().__init__(compute_on_cpu=compute_on_cpu, **kwargs) 40 | self.distances = distances 41 | self.add_state("scores", default=[], dist_reduce_fx="cat") 42 | self.add_state("y_true", default=[], dist_reduce_fx="cat") 43 | 44 | def update(self, scores: torch.Tensor, y_true: torch.Tensor) -> None: 45 | self.scores.append(scores) 46 | self.y_true.append(y_true) 47 | 48 | def compute(self) -> torch.Tensor: 49 | scores = dim_zero_cat(self.scores) 50 | y_true = dim_zero_cat(self.y_true) 51 | _, _, _, eer = det_curve(y_true.cpu(), scores.cpu(), distances=self.distances) 52 | return torch.tensor(eer) 53 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/utils/multi_task.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2023- CNRS 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | 24 | from typing import Any, Callable, Tuple, Union 25 | 26 | from pyannote.audio.core.model import Specifications 27 | 28 | 29 | def map_with_specifications( 30 | specifications: Union[Specifications, Tuple[Specifications]], 31 | func: Callable, 32 | *iterables, 33 | ) -> Union[Any, Tuple[Any]]: 34 | """Compute the function using arguments from each of the iterables 35 | 36 | Returns a tuple if provided `specifications` is a tuple, 37 | otherwise returns the function return value. 38 | 39 | Parameters 40 | ---------- 41 | specifications : (tuple of) Specifications 42 | Specifications or tuple of specifications 43 | func : callable 44 | Function called for each specification with 45 | `func(*iterables[i], specifications=specifications[i])` 46 | *iterables : 47 | List of iterables with same length as `specifications`. 
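A hedged usage sketch of the EqualErrorRate metric defined above; the label convention (1 for same-speaker trials) and the meaning of distances=True (lower scores indicate same-speaker trials) are assumptions to check against pyannote.metrics:

```python
import torch

from pyannote.audio.torchmetrics.classification.equal_error_rate import EqualErrorRate

metric = EqualErrorRate(distances=True)

scores = torch.rand(1000)                # e.g. cosine distances between embeddings
y_true = (torch.rand(1000) > 0.5).int()  # assumed: 1 = same speaker, 0 = different
metric.update(scores, y_true)

print(f"EER: {metric.compute().item():.3f}")  # ~0.5 on random scores, as expected
```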
48 | 49 | Returns 50 | ------- 51 | output : (tuple of) `func` return value(s) 52 | """ 53 | 54 | if isinstance(specifications, Specifications): 55 | return func(*iterables, specifications=specifications) 56 | 57 | return tuple( 58 | func(*i, specifications=s) for s, *i in zip(specifications, *iterables) 59 | ) 60 | -------------------------------------------------------------------------------- /pyannote-audio/notebook/freeze.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from pyannote.database import get_protocol, FileFinder\n", 10 | "protocol = get_protocol('Debug.SpeakerDiarization.Debug', \n", 11 | " preprocessors={\"audio\": FileFinder()})" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "from pyannote.audio.tasks import VoiceActivityDetection\n", 21 | "from pyannote.audio.models.segmentation.debug import SimpleSegmentationModel\n", 22 | "import pytorch_lightning as pl" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "vad = VoiceActivityDetection(protocol, duration=2., batch_size=16, num_workers=4)\n", 32 | "model = SimpleSegmentationModel(task=vad)\n", 33 | "trainer = pl.Trainer(max_epochs=1)\n", 34 | "_ = trainer.fit(model)" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "summary = model.summarize('full')" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "model.freeze_up_to('lstm')" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "model.unfreeze_up_to('mfcc.MelSpectrogram.spectrogram')" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "model.freeze_by_name(['lstm', 'activation'])" 71 | ] 72 | } 73 | ], 74 | "metadata": { 75 | "kernelspec": { 76 | "display_name": "Python 3", 77 | "language": "python", 78 | "name": "python3" 79 | }, 80 | "language_info": { 81 | "codemirror_mode": { 82 | "name": "ipython", 83 | "version": 3 84 | }, 85 | "file_extension": ".py", 86 | "mimetype": "text/x-python", 87 | "name": "python", 88 | "nbconvert_exporter": "python", 89 | "pygments_lexer": "ipython3", 90 | "version": "3.8.5" 91 | } 92 | }, 93 | "nbformat": 4, 94 | "nbformat_minor": 4 95 | } 96 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/cli/lr_schedulers/CosineAnnealingWarmRestarts.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2022 CNRS 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The 
above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | from typing import Optional 24 | 25 | from torch.optim import Optimizer 26 | from torch.optim.lr_scheduler import ( 27 | CosineAnnealingWarmRestarts as _CosineAnnealingWarmRestarts, 28 | ) 29 | 30 | 31 | def CosineAnnealingWarmRestarts( 32 | optimizer: Optimizer, 33 | min_lr: float = 1e-8, 34 | max_lr: float = 1e-3, 35 | patience: int = 1, 36 | num_batches_per_epoch: Optional[int] = None, 37 | **kwargs, 38 | ): 39 | """Wrapper around CosineAnnealingWarmRestarts 40 | 41 | Parameters 42 | ---------- 43 | optimizer : Optimizer 44 | Optimizer 45 | min_lr : float, optional 46 | Defaults to 1e-8. 47 | max_lr : float, optional 48 | Defaults to 1e-3 49 | patience : int, optional 50 | Number of epochs per cycle. Defaults to 1. 51 | num_batches_per_epoch : int, optional 52 | Number of batches per epoch. 53 | """ 54 | 55 | # initialize optimizer lr to max_lr 56 | for g in optimizer.param_groups: 57 | g["lr"] = max_lr 58 | 59 | num_steps = patience * num_batches_per_epoch 60 | 61 | return { 62 | "scheduler": _CosineAnnealingWarmRestarts( 63 | optimizer, num_steps, eta_min=min_lr, T_mult=2 64 | ), 65 | "interval": "step", 66 | } 67 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/cli/lr_schedulers/CyclicLR.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2021 CNRS 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
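The wrapper above returns a Lightning-style scheduler dictionary rather than a bare scheduler. A hedged sketch of calling it directly; the optimizer and step counts are illustrative:

```python
import torch

from pyannote.audio.cli.lr_schedulers.CosineAnnealingWarmRestarts import (
    CosineAnnealingWarmRestarts,
)

net = torch.nn.Linear(10, 2)  # stand-in model
optimizer = torch.optim.AdamW(net.parameters())

lr_config = CosineAnnealingWarmRestarts(
    optimizer, min_lr=1e-8, max_lr=1e-3, patience=2, num_batches_per_epoch=100
)
# lr_config == {"scheduler": <torch CosineAnnealingWarmRestarts>, "interval": "step"},
# i.e. the shape Lightning expects under "lr_scheduler" in configure_optimizers().
```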
22 | 23 | from typing import Optional 24 | 25 | from torch.optim import Optimizer 26 | from torch.optim.lr_scheduler import CyclicLR as _CyclicLR 27 | 28 | 29 | def CyclicLR( 30 | optimizer: Optimizer, 31 | min_lr: float = 1e-8, 32 | max_lr: float = 1e-3, 33 | mode: str = "triangular2", 34 | patience: int = 50, 35 | num_batches_per_epoch: Optional[int] = None, 36 | **kwargs, 37 | ): 38 | """Wrapper around CyclicLR learning rate scheduler 39 | 40 | Parameters 41 | ---------- 42 | optimizer : Optimizer 43 | Optimizer 44 | min_lr : float, optional 45 | Defaults to 1e-8. 46 | max_lr : float, optional 47 | Defaults to 1e-3 48 | patience : int, optional 49 | Number of epochs per cycle. Defaults to 50. 50 | num_batches_per_epoch : int, optional 51 | Number of batches per epoch. 52 | mode : {"triangular", "triangular2"}, optional 53 | Defaults to "triangular2". 54 | """ 55 | 56 | step_size_up = int(0.5 * patience * num_batches_per_epoch) 57 | 58 | return { 59 | "scheduler": _CyclicLR( 60 | optimizer, 61 | base_lr=min_lr, 62 | max_lr=max_lr, 63 | step_size_up=step_size_up, 64 | mode=mode, 65 | cycle_momentum=False, 66 | ), 67 | "interval": "step", 68 | } 69 | -------------------------------------------------------------------------------- /recipes/diar_ssl_pruning/README.md: -------------------------------------------------------------------------------- 1 | # Structured Pruning of WavLM 2 | This directory contains scripts for structured pruning of [WavLM](https://arxiv.org/pdf/2110.13900) applied to speaker diarization. 3 | 4 | ## How to run 5 | - Convert WavLM format from HuggingFace to our custom format: See `convert_wavlm_from_hf.py`. 6 | - Fine-tune WavLM for diarization: See `../diar_ssl/run_stage.sh`. 7 | - Start pruning training: 8 | `bash -i run_stage.sh`. 9 | 10 | 11 | ## Results (collar=0s) 12 | | System | Sparsity | Params | MACs | Speedup | AMI | AISHELL-4 | AliMeeting | Macro | 13 | |:----------------|:----------:|:----------:|:---------:|:---------:|:------:|:------------:|:-------------:|:--------:| 14 | | Fbank | - | - | - | - | 19.7 | 12.5 | 21.0 | 17.7 | 15 | | WavLM Base+ | 0% | 94.4M | 6.9G | - | 15.6 | 11.8 | 17.7 | 15.0 | 16 | | | 80% | 18.8M | 1.1G | 4.0× | 15.7 | 12.1 | 17.9 | 15.2 | 17 | | | 90% | 9.4M | 0.6G | 5.7× | 17.2 | 12.1 | 19.2 | 16.1 | 18 | | WavLM Large | 0% | 316.6M | 17.8G | - | 14.8 | 11.3 | 16.3 | 14.1 | 19 | | | 80% | 63.3M | 3.8G | 2.6× | 15.1 | 11.3 | 15.8 | 14.1 | 20 | | | 90% | 30.6M | 1.8G | 3.5× | 15.7 | 11.2 | 17.6 | 14.8 | 21 | 22 | ## Citation 23 | If you found this work helpful, please consider citing: 24 | J. Han, F. Landini, J. Rohdin, A. Silnova, M. Diez, J. Cernocky and L. Burget, [Fine-tune Before Structured Pruning: Towards Compact and Accurate Self-Supervised Models for Speaker Diarization](https://arxiv.org/pdf/2505.24111), in Proc. INTERSPEECH, 2025. 
25 | ``` 26 | @article{han2025fine, 27 | title={Fine-tune Before Structured Pruning: Towards Compact and Accurate Self-Supervised Models for Speaker Diarization}, 28 | author={Han, Jiangyu and Landini, Federico and Rohdin, Johan and Silnova, Anna and Diez, Mireia and Cernocky, Jan and Burget, Lukas}, 29 | journal={arXiv preprint arXiv:2505.24111}, 30 | year={2025} 31 | } 32 | 33 | @article{han2025efficient, 34 | title={Efficient and Generalizable Speaker Diarization via Structured Pruning of Self-Supervised Models}, 35 | author={Han, Jiangyu and P{\'a}lka, Petr and Delcroix, Marc and Landini, Federico and Rohdin, Johan and Cernock{\`y}, Jan and Burget, Luk{\'a}{\v{s}}}, 36 | journal={arXiv preprint arXiv:2506.18623}, 37 | year={2025} 38 | } 39 | ``` 40 | 41 | ## Acknowledgments 42 | We thank the authors of [DPHuBERT](https://github.com/pyf98/DPHuBERT) for open-sourcing their code. 43 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/models/embedding/wespeaker/convert.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2023 CNRS 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
22 | 23 | # Script used to convert from WeSpeaker to pyannote.audio 24 | 25 | import sys 26 | from pathlib import Path 27 | 28 | import pytorch_lightning as pl 29 | import torch 30 | 31 | import pyannote.audio.models.embedding.wespeaker as wespeaker 32 | from pyannote.audio import Model 33 | from pyannote.audio.core.task import Problem, Resolution, Specifications 34 | 35 | wespeaker_checkpoint_dir = sys.argv[1] # /path/to/wespeaker_cnceleb-resnet34-LM 36 | 37 | wespeaker_checkpoint = Path(wespeaker_checkpoint_dir) / "wespeaker.pt" 38 | 39 | depth = Path(wespeaker_checkpoint_dir).parts[-1].split("-")[-2][6:] # '34' 40 | Klass = getattr(wespeaker, f"WeSpeakerResNet{depth}") # WeSpeakerResNet34 41 | 42 | duration = 5.0 # whatever 43 | specifications = Specifications( 44 | problem=Problem.REPRESENTATION, resolution=Resolution.CHUNK, duration=duration 45 | ) 46 | 47 | state_dict = torch.load(wespeaker_checkpoint, map_location=torch.device("cpu")) 48 | state_dict.pop("projection.weight") 49 | 50 | model = Klass() 51 | model.resnet.load_state_dict(state_dict, strict=True) 52 | model.specifications = specifications 53 | 54 | checkpoint = {"state_dict": model.state_dict()} 55 | model.on_save_checkpoint(checkpoint) 56 | checkpoint["pytorch-lightning_version"] = pl.__version__ 57 | 58 | pyannote_checkpoint = Path(wespeaker_checkpoint_dir) / "pytorch_model.bin" 59 | torch.save(checkpoint, pyannote_checkpoint) 60 | 61 | model = Model.from_pretrained(pyannote_checkpoint) 62 | print(model) 63 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Custom 2 | exp* 3 | *.log 4 | *sge* 5 | *.err 6 | *.out 7 | *_debug* 8 | .idea/ 9 | .vscode/* 10 | !.vscode/launch.json 11 | generated/ # generated sphinx files 12 | _build/ 13 | *.csv 14 | core.* 15 | tmp/ 16 | 17 | # wavlm 18 | WavLM-Base+.pt 19 | 20 | # Ruff 21 | .ruff_cache/ 22 | 23 | # Byte-compiled / optimized / DLL files 24 | __pycache__/ 25 | *.py[cod] 26 | *$py.class 27 | 28 | # C extensions 29 | *.so 30 | 31 | # Distribution / packaging 32 | .Python 33 | build/ 34 | develop-eggs/ 35 | dist/ 36 | downloads/ 37 | eggs/ 38 | .eggs/ 39 | lib/ 40 | lib64/ 41 | parts/ 42 | sdist/ 43 | var/ 44 | wheels/ 45 | pip-wheel-metadata/ 46 | share/python-wheels/ 47 | *.egg-info/ 48 | .installed.cfg 49 | *.egg 50 | MANIFEST 51 | 52 | # PyInstaller 53 | # Usually these files are written by a python script from a template 54 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 55 | *.manifest 56 | *.spec 57 | 58 | # Installer logs 59 | pip-log.txt 60 | pip-delete-this-directory.txt 61 | 62 | # Unit test / coverage reports 63 | htmlcov/ 64 | .tox/ 65 | .nox/ 66 | .coverage 67 | .coverage.* 68 | .cache 69 | nosetests.xml 70 | coverage.xml 71 | *.cover 72 | *.py,cover 73 | .hypothesis/ 74 | .pytest_cache/ 75 | 76 | # Translations 77 | *.mo 78 | *.pot 79 | 80 | # Django stuff: 81 | *.log 82 | local_settings.py 83 | db.sqlite3 84 | db.sqlite3-journal 85 | 86 | # Flask stuff: 87 | instance/ 88 | .webassets-cache 89 | 90 | # Scrapy stuff: 91 | .scrapy 92 | 93 | # Sphinx documentation 94 | docs/_build/ 95 | 96 | # PyBuilder 97 | target/ 98 | 99 | # Jupyter Notebook 100 | .ipynb_checkpoints 101 | 102 | # IPython 103 | profile_default/ 104 | ipython_config.py 105 | 106 | # pyenv 107 | .python-version 108 | 109 | # pipenv 110 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
111 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 112 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 113 | # install all needed dependencies. 114 | #Pipfile.lock 115 | 116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | -------------------------------------------------------------------------------- /pyannote-audio/tests/io_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio 3 | from pyannote.core import Segment 4 | from torch import Tensor 5 | 6 | from pyannote.audio.core.io import Audio 7 | 8 | 9 | def test_audio_resample(): 10 | "Audio is correctly resampled when it isn't the correct sample rate" 11 | test_file = "tests/data/dev00.wav" 12 | info = torchaudio.info(test_file) 13 | old_sr = info.sample_rate 14 | loader = Audio(sample_rate=old_sr // 2, mono="downmix") 15 | wav, sr = loader(test_file) 16 | assert isinstance(wav, Tensor) 17 | assert sr == old_sr // 2 18 | 19 | 20 | def test_basic_load_with_defaults(): 21 | test_file = "tests/data/dev00.wav" 22 | loader = Audio(mono="downmix") 23 | wav, sr = loader(test_file) 24 | assert isinstance(wav, Tensor) 25 | 26 | 27 | def test_correct_audio_channel(): 28 | "When we specify an audio channel, it is chosen correctly" 29 | waveform = torch.rand(2, 16000 * 2) 30 | loader = Audio(mono="downmix") 31 | wav, sr = loader({"waveform": waveform, "sample_rate": 16000, "channel": 1}) 32 | assert torch.equal(wav, waveform[1:2]) 33 | assert sr == 16000 34 | 35 | 36 | def test_can_load_with_waveform(): 37 | "We can load a raw waveform" 38 | waveform = torch.rand(2, 16000 * 2) 39 | loader = Audio(mono="downmix") 40 | wav, sr = loader({"waveform": waveform, "sample_rate": 16000}) 41 | assert isinstance(wav, Tensor) 42 | assert sr == 16000 43 | 44 | 45 | def test_can_crop(): 46 | "Cropping works when we give a Segment" 47 | test_file = "tests/data/dev00.wav" 48 | loader = Audio(mono="downmix") 49 | segment = Segment(0.2, 0.7) 50 | wav, sr = loader.crop(test_file, segment) 51 | assert wav.shape[1] / sr == 0.5 52 | 53 | 54 | def test_can_crop_waveform(): 55 | "Cropping works on raw waveforms" 56 | waveform = torch.rand(1, 16000 * 2) 57 | loader = Audio(mono="downmix") 58 | segment = Segment(0.2, 0.7) 59 | wav, sr = loader.crop({"waveform": waveform, "sample_rate": 16000}, segment) 60 | assert isinstance(wav, Tensor) 61 | assert sr == 16000 62 | 63 | 64 | # File Like Object Tests 65 | def test_can_load_from_file_like(): 66 | "Load entire wav of file like" 67 | loader = Audio(mono="downmix") 68 | 69 | with open("tests/data/dev00.wav", "rb") as f: 70 | wav, sr = loader(f) 71 | 72 | assert isinstance(wav, Tensor) 73 | assert sr == 16000 74 | 75 | 76 | def test_can_crop_from_file_like(): 77 | "Load cropped sections from file like objects" 78 | loader = Audio(mono="downmix") 79 | 80 | with open("tests/data/dev00.wav", "rb") as 
f: 81 | segment = Segment(0.2, 0.7) 82 | wav, sr = loader.crop(f, segment) 83 | 84 | assert isinstance(wav, Tensor) 85 | assert sr == 16000 86 | assert wav.shape[1] == 0.5 * 16000 87 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/cli/lr_schedulers/ReduceLROnPlateau.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2021 CNRS 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | 24 | from typing import Optional, Text 25 | 26 | from torch.optim import Optimizer 27 | from torch.optim.lr_scheduler import ReduceLROnPlateau as _ReduceLROnPlateau 28 | 29 | 30 | def ReduceLROnPlateau( 31 | optimizer: Optimizer, 32 | monitor: Optional[Text] = None, 33 | direction: Optional[Text] = "min", 34 | min_lr: float = 1e-8, 35 | max_lr: float = 1e-3, 36 | factor: float = 0.5, 37 | patience: int = 50, 38 | **kwargs, 39 | ): 40 | """Wrapper around ReduceLROnPlateau learning rate scheduler 41 | 42 | Parameters 43 | ---------- 44 | optimizer : Optimizer 45 | Optimizer 46 | min_lr : float, optional 47 | Defaults to 1e-8. 48 | max_lr : float, optional 49 | Defaults to 1e-3 50 | factor : float, optional 51 | Defaults to 0.5 52 | patience : int, optional 53 | Wait that many epochs with no improvement before reducing the learning rate. 54 | Defaults to 50. 55 | monitor : str, optional 56 | Value to monitor 57 | direction : {"min", "max"}, optional 58 | "min" (resp. "max") means smaller (resp. larger) is better. 
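    Returns
    -------
    dict
        Lightning-style lr_scheduler config with keys "scheduler",
        "interval", "monitor" and "strict". Note that the learning rate of
        every optimizer parameter group is first reset to max_lr.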
59 | """ 60 | 61 | # initialize optimizer lr to max_lr 62 | for g in optimizer.param_groups: 63 | g["lr"] = max_lr 64 | 65 | return { 66 | "scheduler": _ReduceLROnPlateau( 67 | optimizer, 68 | mode=direction, 69 | factor=factor, 70 | patience=patience, 71 | threshold=0.0001, 72 | threshold_mode="rel", 73 | cooldown=0, 74 | min_lr=min_lr, 75 | eps=1e-08, 76 | verbose=False, 77 | ), 78 | "interval": "epoch", 79 | "monitor": monitor, 80 | "strict": True, 81 | } 82 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # ----------------- Build System ----------------- 2 | [build-system] 3 | requires = ["flit_core >=3.2,<4"] 4 | build-backend = "flit_core.buildapi" 5 | 6 | # ----------------- Metadata ----------------- 7 | [project] 8 | name = "diarizen" 9 | description = """DiariZen is a speaker diarization toolkit based on AudioZen . 10 | The AudioZen is mainly maintained by Xiang Hao (Hong Kong Polytechnic University). 11 | The DiariZen is mainly maintained by Jiangyu Han (Brno University of Technology)""" 12 | authors = [ 13 | { name = "Xiang Hao", email = "haoxiangsnr@gmail.com" }, 14 | { name = "Jiangyu Han", email = "ihan@fit.vut.cz"} 15 | ] 16 | readme = "README.md" 17 | requires-python = ">=3.10" 18 | version = "0.0.1" 19 | classifiers = [ 20 | "Programming Language :: Python :: 3.10", 21 | "Development Status :: 2 - Pre-Alpha", 22 | "License :: OSI Approved :: MIT License", 23 | "Environment :: GPU :: NVIDIA CUDA", 24 | "Operating System :: OS Independent", 25 | ] 26 | keywords = [ 27 | "multiple purposes", 28 | "speaker localization/tracking", 29 | "dereverberation", 30 | "enhancement", 31 | "separation", 32 | "recognition", 33 | "diarization" 34 | ] 35 | [project.optional-dependencies] 36 | test = ["pytest", "pytest-cov"] 37 | docs = ["importlib_metadata", "sphinx-autoapi", "sphinx-rtd-theme", "myst-parser", "myst-nb"] 38 | build = ["flit", "python-semantic-release", "sphinx-autobuild"] 39 | [project.urls] 40 | Source = "https://github.com/BUTSpeechFIT/DiariZen" 41 | 42 | # ----------------- Tools Configuration ----------------- 43 | [tool.semantic_release] 44 | version_toml = "pyproject.toml:project.version" # version location 45 | branch = "main" # branch to make releases 46 | changelog_file = "CHANGELOG.md" # changelog file` 47 | build_command = "flit build" # build dists 48 | upload_to_release = true # auto-create GitHub release 49 | upload_to_repository = false # don't auto-upload to PyPI 50 | remove_dist = false # don't remove dists 51 | patch_without_tag = false # patch release by default 52 | commit_author = "Jiangyu Han " 53 | commit_subject = "Release {version}" 54 | commit_message = "" # commit message 55 | 56 | [tool.ruff] 57 | # Never enforce `E501` (line length violations). 58 | ignore = ["C901", "E501", "E741", "F402", "F823"] 59 | select = ["C", "E", "F", "I", "W"] 60 | line-length = 119 61 | 62 | # Ignore import violations in all `__init__.py` files. 63 | [tool.ruff.per-file-ignores] 64 | "__init__.py" = ["E402", "F401", "F403", "F811"] 65 | 66 | [tool.ruff.isort] 67 | lines-after-imports = 2 68 | known-first-party = ["diarizen"] 69 | 70 | [tool.ruff.format] 71 | # Like Black, use double quotes for strings. 72 | quote-style = "double" 73 | # Like Black, indent with spaces, rather than tabs. 74 | indent-style = "space" 75 | # Like Black, respect magic trailing commas. 
76 | skip-magic-trailing-comma = false 77 | # Like Black, automatically detect the appropriate line ending. 78 | line-ending = "auto" 79 | docstring-code-format = true 80 | docstring-code-line-length = 119 81 | 82 | [tool.ruff.lint.pydocstyle] 83 | convention = "google" 84 | -------------------------------------------------------------------------------- /diarizen/logger.py: -------------------------------------------------------------------------------- 1 | # Licensed under the MIT license. 2 | # Copy from https://github.com/haoxiangsnr/spiking-fullsubnet/blob/main/audiozen/logger.py 3 | # Copyright 2024 Hong Kong Polytechnic University (author: Xiang Hao, haoxiangsnr@gmail.com) 4 | 5 | import logging 6 | import os 7 | import time 8 | from pathlib import Path 9 | 10 | import toml 11 | from torch.utils.tensorboard import SummaryWriter # type: ignore 12 | 13 | 14 | class TensorboardLogger(SummaryWriter): 15 | def __init__(self, log_dir: str = "") -> None: 16 | super().__init__(log_dir=log_dir, max_queue=5, flush_secs=30) 17 | 18 | def log_config(self, config: dict) -> None: 19 | self.add_text( 20 | tag="Configuration", 21 | text_string=f"
```  \n{toml.dumps(config)}  \n```
", 22 | global_step=1, 23 | ) 24 | 25 | 26 | def init_logging_logger(config): 27 | """Initialize logging logger with handlers. 28 | 29 | Args: 30 | log_fpath: Path to save log file. 31 | 32 | Examples: 33 | >>> # Call this function at the beginning of main file. 34 | >>> init_logger(log_fpath="log_path") 35 | >>> # Use this logger in other modules. 36 | >>> import logging 37 | >>> logger = logging.getLogger(__name__) 38 | >>> logger.info("info message") 39 | """ 40 | # Parse log_fpath 41 | log_dir: Path = Path(config["meta"]["save_dir"]).expanduser().absolute() / config["meta"]["exp_id"] 42 | log_dir.mkdir(parents=True, exist_ok=True) 43 | 44 | # disable logging for libraries that use standard logging module 45 | logging.getLogger("matplotlib").setLevel(logging.WARNING) 46 | logging.getLogger("numba").setLevel(logging.ERROR) 47 | 48 | # Create logger 49 | logger = logging.getLogger() 50 | 51 | # Set the lowest level of root logger and controls logging via handlers' level 52 | logger.setLevel(logging.DEBUG) 53 | 54 | # Get log level from environment variable 55 | log_level = os.environ.get("LOG_LEVEL", "INFO").upper() 56 | 57 | # Create a console handler and set level to info 58 | console_handler = logging.StreamHandler() 59 | console_handler.setLevel(level=log_level) 60 | 61 | # Create a file handler and set the logger level to debug 62 | time_now = time.strftime("%Y_%m_%d--%H_%M_%S") 63 | file_handler = logging.FileHandler(str(log_dir / f"{config['meta']['exp_id']}_{time_now}.log")) 64 | file_handler.setLevel(level=log_level) 65 | 66 | # Create formatters (file logger have more info) 67 | console_formatter = logging.Formatter( 68 | "%(asctime)s: %(message)s", 69 | datefmt="%m-%d %H:%M:%S", 70 | ) 71 | file_formatter = logging.Formatter( 72 | "%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d]: %(message)s", 73 | datefmt="%m-%d %H:%M:%S", 74 | ) 75 | 76 | # Add formatter to ch 77 | console_handler.setFormatter(console_formatter) 78 | file_handler.setFormatter(file_formatter) 79 | 80 | # Add ch to logger 81 | logger.addHandler(console_handler) 82 | logger.addHandler(file_handler) 83 | 84 | logger.info(f"Initialized logger with log file in {log_dir.as_posix()}.") 85 | -------------------------------------------------------------------------------- /pyannote-audio/tests/inference_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import pytorch_lightning as pl 4 | from pyannote.core import SlidingWindowFeature 5 | from pyannote.database import FileFinder, get_protocol 6 | 7 | from pyannote.audio import Inference, Model 8 | from pyannote.audio.core.task import Resolution 9 | from pyannote.audio.models.segmentation.debug import SimpleSegmentationModel 10 | from pyannote.audio.tasks import VoiceActivityDetection 11 | 12 | HF_SAMPLE_MODEL_ID = "pyannote/ci-segmentation" 13 | 14 | 15 | def test_hf_download_inference(): 16 | inference = Inference(HF_SAMPLE_MODEL_ID, device="cpu") 17 | assert isinstance(inference, Inference) 18 | 19 | 20 | def test_hf_download_model(): 21 | model = Model.from_pretrained(HF_SAMPLE_MODEL_ID) 22 | assert isinstance(model, Model) 23 | 24 | 25 | @pytest.fixture() 26 | def trained(): 27 | protocol = get_protocol( 28 | "Debug.SpeakerDiarization.Debug", preprocessors={"audio": FileFinder()} 29 | ) 30 | vad = VoiceActivityDetection(protocol, duration=2.0, batch_size=16, num_workers=4) 31 | model = SimpleSegmentationModel(task=vad) 32 | trainer = pl.Trainer(fast_dev_run=True, 
accelerator="cpu") 33 | trainer.fit(model) 34 | return protocol, model 35 | 36 | 37 | @pytest.fixture() 38 | def pretrained_model(): 39 | return Model.from_pretrained(HF_SAMPLE_MODEL_ID) 40 | 41 | 42 | @pytest.fixture() 43 | def dev_file(): 44 | protocol = get_protocol( 45 | "Debug.SpeakerDiarization.Debug", preprocessors={"audio": FileFinder()} 46 | ) 47 | return next(protocol.development()) 48 | 49 | 50 | def test_duration_warning(trained): 51 | protocol, model = trained 52 | with pytest.warns(UserWarning): 53 | duration = model.specifications.duration 54 | new_duration = duration + 1 55 | Inference(model, duration=new_duration, step=0.1, batch_size=128) 56 | 57 | 58 | def test_step_check_warning(trained): 59 | protocol, model = trained 60 | with pytest.raises(ValueError): 61 | duration = model.specifications.duration 62 | Inference(model, step=duration + 1, batch_size=128) 63 | 64 | 65 | def test_invalid_window_fails(trained): 66 | protocol, model = trained 67 | with pytest.raises(ValueError): 68 | Inference(model, window="unknown") 69 | 70 | 71 | def test_invalid_resolution_fails(trained): 72 | protocol, model = trained 73 | with pytest.warns(UserWarning): 74 | model.specifications.resolution = Resolution.FRAME 75 | Inference(model, window="whole", batch_size=128) 76 | 77 | 78 | def test_whole_window_slide(trained): 79 | protocol, model = trained 80 | inference = Inference(model, window="whole", batch_size=128) 81 | dev_file = next(protocol.development()) 82 | output = inference(dev_file) 83 | assert isinstance(output, np.ndarray) 84 | 85 | 86 | def test_on_file_path(trained): 87 | protocol, model = trained 88 | inference = Inference(model, batch_size=128) 89 | output = inference("tests/data/dev00.wav") 90 | assert isinstance(output, SlidingWindowFeature) 91 | 92 | 93 | def test_skip_aggregation(pretrained_model, dev_file): 94 | inference = Inference(pretrained_model, skip_aggregation=True) 95 | scores = inference(dev_file) 96 | assert len(scores.data.shape) == 3 97 | -------------------------------------------------------------------------------- /pyannote-audio/tests/utils/test_permutation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from pyannote.audio.utils.permutation import permutate 5 | 6 | 7 | def test_permutate_torch(): 8 | 9 | num_frames, num_speakers = 10, 3 10 | 11 | actual_permutations = [ 12 | (0, 1, 2), 13 | (0, 2, 1), 14 | (1, 0, 2), 15 | (1, 2, 0), 16 | (2, 0, 1), 17 | (2, 1, 0), 18 | ] 19 | batch_size = len(actual_permutations) 20 | 21 | y2 = torch.randn((num_frames, num_speakers)) 22 | y1 = torch.zeros((batch_size, num_frames, num_speakers)) 23 | 24 | for p, permutation in enumerate(actual_permutations): 25 | y1[p] = y2[:, permutation] 26 | 27 | permutated_y2, permutations = permutate(y1, y2) 28 | assert actual_permutations == permutations 29 | 30 | for p, permutation in enumerate(actual_permutations): 31 | np.testing.assert_allclose(permutated_y2[p], y2[:, permutation]) 32 | 33 | 34 | def test_permutate_numpy(): 35 | 36 | num_frames, num_speakers = 10, 3 37 | 38 | actual_permutations = [ 39 | (0, 1, 2), 40 | (0, 2, 1), 41 | (1, 0, 2), 42 | (1, 2, 0), 43 | (2, 0, 1), 44 | (2, 1, 0), 45 | ] 46 | batch_size = len(actual_permutations) 47 | 48 | y2 = np.random.randn(num_frames, num_speakers) 49 | y1 = np.zeros((batch_size, num_frames, num_speakers)) 50 | 51 | for p, permutation in enumerate(actual_permutations): 52 | y1[p] = y2[:, permutation] 53 | 54 | permutated_y2, permutations = 
permutate(y1, y2) 55 | assert actual_permutations == permutations 56 | 57 | for p, permutation in enumerate(actual_permutations): 58 | np.testing.assert_allclose(permutated_y2[p], y2[:, permutation]) 59 | 60 | 61 | def test_permutate_less_speakers(): 62 | 63 | num_frames = 10 64 | 65 | actual_permutations = [ 66 | (0, 1, None), 67 | (0, None, 1), 68 | (1, 0, None), 69 | (1, None, 0), 70 | (None, 0, 1), 71 | (None, 1, 0), 72 | ] 73 | batch_size = len(actual_permutations) 74 | 75 | y2 = np.random.randn(num_frames, 2) 76 | y1 = np.zeros((batch_size, num_frames, 3)) 77 | 78 | for p, permutation in enumerate(actual_permutations): 79 | for i, j in enumerate(permutation): 80 | if j is not None: 81 | y1[p, :, i] = y2[:, j] 82 | 83 | permutated_y2, permutations = permutate(y1, y2) 84 | 85 | assert permutations == actual_permutations 86 | 87 | 88 | def test_permutate_more_speakers(): 89 | 90 | num_frames = 10 91 | 92 | actual_permutations = [ 93 | (0, 1), 94 | (0, 2), 95 | (1, 0), 96 | (1, 2), 97 | (2, 0), 98 | (2, 1), 99 | ] 100 | batch_size = len(actual_permutations) 101 | 102 | y2 = np.random.randn(num_frames, 3) 103 | y1 = np.zeros((batch_size, num_frames, 2)) 104 | 105 | for p, permutation in enumerate(actual_permutations): 106 | for i, j in enumerate(permutation): 107 | y1[p, :, i] = y2[:, j] 108 | 109 | permutated_y2, permutations = permutate(y1, y2) 110 | 111 | assert permutations == actual_permutations 112 | np.testing.assert_allclose(permutated_y2, y1) 113 | -------------------------------------------------------------------------------- /pyannote-audio/notebook/augmentation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# gett a 5s excerpt of first test file\n", 10 | "from pyannote.database import get_protocol, FileFinder\n", 11 | "protocol = get_protocol('Debug.SpeakerDiarization.Debug', \n", 12 | " preprocessors={\"audio\": FileFinder()})\n", 13 | "\n", 14 | "from pyannote.audio.core.io import Audio\n", 15 | "audio = Audio(sample_rate=16000, mono=\"downmix\")\n", 16 | "file = next(protocol.test())\n", 17 | "\n", 18 | "from pyannote.core import Segment\n", 19 | "waveform, sample_rate = audio.crop(file, Segment(5, 10))\n", 20 | "\n", 21 | "import torch\n", 22 | "waveforms = torch.tensor(waveform)[None, :]" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "# play the excerpt\n", 32 | "from IPython.display import Audio as Play\n", 33 | "Play(waveforms.squeeze(), rate=sample_rate, normalize=False, autoplay=True)" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "# define a model that simply returns the waveform\n", 43 | "from pyannote.audio.core.model import Model\n", 44 | "class Passthrough(Model):\n", 45 | " def forward(self, waveforms):\n", 46 | " return waveforms\n", 47 | " \n", 48 | "identity = Passthrough()" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "# pass the waveform through this \"identity\" model\n", 58 | "Play(identity(waveforms).squeeze(), rate=sample_rate, normalize=False, autoplay=True)" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 
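    "# (note: torch_audiomentations transforms are stochastic; with p=0.5 the\n",
    "# gain below is applied on roughly half of the calls, so outputs vary)\n",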
67 | "# add one torch_audiomentations waveform transform to the model\n", 68 | "from pyannote.audio.augmentation.registry import register_augmentation\n", 69 | "from torch_audiomentations import Gain\n", 70 | "gain = Gain(\n", 71 | " min_gain_in_db=-15.0,\n", 72 | " max_gain_in_db=5.0,\n", 73 | " p=0.5)\n", 74 | "register_augmentation(gain, identity, when='input')" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "# pass the waveform through the \"augmented\" model\n", 84 | "Play(identity(waveforms).squeeze(), rate=sample_rate, normalize=False, autoplay=True)" 85 | ] 86 | } 87 | ], 88 | "metadata": { 89 | "kernelspec": { 90 | "display_name": "Python 3", 91 | "language": "python", 92 | "name": "python3" 93 | }, 94 | "language_info": { 95 | "codemirror_mode": { 96 | "name": "ipython", 97 | "version": 3 98 | }, 99 | "file_extension": ".py", 100 | "mimetype": "text/x-python", 101 | "name": "python", 102 | "nbconvert_exporter": "python", 103 | "pygments_lexer": "ipython3", 104 | "version": "3.7.9" 105 | } 106 | }, 107 | "nbformat": 4, 108 | "nbformat_minor": 4 109 | } 110 | -------------------------------------------------------------------------------- /diarizen/models/pruning/model_distill_prune.py: -------------------------------------------------------------------------------- 1 | # Licensed under the MIT license. 2 | # Copyright 2025 Brno University of Technology (author: Jiangyu Han, ihan@fit.vut.cz) 3 | # This work is inspired by: https://github.com/pyf98/DPHuBERT 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from diarizen.models.module.wav2vec2.model import wav2vec2_model 9 | 10 | class Model(nn.Module): 11 | def __init__( 12 | self, 13 | teacher_ckpt: str, 14 | student_ckpt: str, 15 | pruning_units: str = "conv,head,interm", 16 | distill_layers: str = "0,4,8,12", 17 | ): 18 | super().__init__() 19 | 20 | self.distill_layers = [int(l) for l in distill_layers.split(",")] 21 | 22 | self.teacher_model = self.build_teacher(teacher_ckpt) 23 | self.student_model, self.student_config = self.build_student(student_ckpt, pruning_units) 24 | 25 | def build_teacher(self, teacher_ckpt): 26 | teacher_ckpt = torch.load(teacher_ckpt, map_location="cpu") 27 | teacher_model = wav2vec2_model(**teacher_ckpt["config"]) 28 | teacher_result = teacher_model.load_state_dict(teacher_ckpt["state_dict"], strict=False) 29 | print(f"Load pretrained ckpt to teacher: missing {teacher_result.missing_keys}, unexpected {teacher_result.unexpected_keys}") 30 | 31 | # freeze teacher 32 | for p in teacher_model.parameters(): 33 | p.requires_grad = False 34 | print("Freeze parameters of the teacher model by setting requires_grad=False") 35 | teacher_model.eval() 36 | return teacher_model 37 | 38 | def build_student(self, student_ckpt, pruning_units): 39 | student_ckpt = torch.load(student_ckpt, map_location="cpu") 40 | pruning_units = pruning_units.split(",") 41 | print(f"Pruning units: {pruning_units}") 42 | student_config = student_ckpt['config'] 43 | student_config.update( 44 | dict( 45 | extractor_prune_conv_channels = "conv" in pruning_units, 46 | encoder_prune_attention_heads = "head" in pruning_units, 47 | encoder_prune_attention_layer = "attlayer" in pruning_units, 48 | encoder_prune_feed_forward_intermediate = "interm" in pruning_units, 49 | encoder_prune_feed_forward_layer = "ffnlayer" in pruning_units, 50 | ) 51 | ) 52 | student_model = wav2vec2_model(**student_config) 53 | student_result = 
student_model.load_state_dict(student_ckpt["state_dict"], strict=False) 54 | print(f"Load pretrained ckpt to student: missing {student_result.missing_keys}, unexpected {student_result.unexpected_keys}") 55 | return student_model, student_config 56 | 57 | def forward(self, waveforms: torch.Tensor) -> torch.Tensor: 58 | self.teacher_model.eval() 59 | with torch.no_grad(): 60 | teacher_hiddens, _ = self.teacher_model.extract_features(waveforms) 61 | teacher_hiddens = torch.stack( 62 | [teacher_hiddens[idx] for idx in self.distill_layers], dim=1 63 | ) # (batch, layer, time, feature) 64 | 65 | student_hiddens, _ = self.student_model.extract_features(waveforms) 66 | student_hiddens = torch.stack( 67 | [student_hiddens[idx] for idx in self.distill_layers], dim=1 68 | ) # (batch, layer, time, feature) 69 | 70 | return student_hiddens, teacher_hiddens -------------------------------------------------------------------------------- /pyannote-audio/setup.cfg: -------------------------------------------------------------------------------- 1 | # This file is used to configure your project. 2 | # Read more about the various options under: 3 | # http://setuptools.readthedocs.io/en/latest/setuptools.html#configuring-setup-using-setup-cfg-files 4 | 5 | [metadata] 6 | name = pyannote-audio 7 | description = Neural speaker diarization 8 | author = Herve Bredin 9 | author-email = herve.bredin@irit.fr 10 | license = mit 11 | long-description = file: README.md 12 | long-description-content-type = text/markdown; charset=UTF-8; variant=GFM 13 | # Change if running only on Windows, Mac or Linux (comma-separated) 14 | platforms = Linux, Mac 15 | # Add here all kinds of additional classifiers as defined under 16 | # https://pypi.python.org/pypi?%3Aaction=list_classifiers 17 | classifiers = 18 | Development Status :: 4 - Beta 19 | Programming Language :: Python 20 | 21 | [options] 22 | zip_safe = False 23 | packages = find: 24 | include_package_data = True 25 | # DON'T CHANGE THE FOLLOWING LINE! IT WILL BE UPDATED BY PYSCAFFOLD! 26 | setup_requires = pyscaffold>=3.2a0,<3.3a0 27 | # Add here dependencies of your project (semicolon/line-separated), e.g. 28 | # install_requires = numpy; scipy 29 | # Require a specific Python version, e.g. Python 2.7 or >= 3.4 30 | python_requires = >=3.7 31 | 32 | [options.packages.find] 33 | where = . 34 | exclude = 35 | tests 36 | 37 | [options.extras_require] 38 | # Add here additional requirements for extra features, to install with: 39 | # `pip install fastaudio[PDF]` like: 40 | # PDF = ReportLab; RXP 41 | # Add here test requirements (semicolon/line-separated) 42 | testing = 43 | pytest>=6.0 44 | pytest-cov>=2.10 45 | jupyter 46 | papermill 47 | dev = 48 | pre_commit>=2.7 49 | recommonmark>=0.6 50 | black>=22.3.0 51 | cli = 52 | hydra-core >=1.1,<1.2 53 | typer >= 0.4.0,<0.5.0 54 | 55 | [options.entry_points] 56 | 57 | console_scripts = 58 | pyannote-audio-train=pyannote.audio.cli.train:train 59 | pyannote-audio-eval=pyannote.audio.cli.evaluate:evaluate 60 | 61 | 62 | [test] 63 | # py.test options when running `python setup.py test` 64 | # addopts = --verbose 65 | extras = True 66 | 67 | [tool:pytest] 68 | # Options for py.test: 69 | # Specify command line options as you would do when invoking py.test directly. 70 | # e.g. --cov-report html (or xml) for html/xml output or --junitxml junit.xml 71 | # in order to write a coverage file that can be read by Jenkins. 
72 | addopts = 73 | --cov pyannote --cov-report term-missing 74 | --verbose 75 | norecursedirs = 76 | dist 77 | build 78 | .tox 79 | testpaths = tests 80 | 81 | [aliases] 82 | dists = bdist_wheel 83 | 84 | [bdist_wheel] 85 | # Use this option if your package is pure-python 86 | universal = 1 87 | 88 | [build_sphinx] 89 | source_dir = doc 90 | build_dir = build/sphinx 91 | 92 | [devpi:upload] 93 | # Options for the devpi: PyPI server and packaging tool 94 | # VCS export must be deactivated since we are using setuptools-scm 95 | no-vcs = 1 96 | formats = bdist_wheel 97 | 98 | [flake8] 99 | # Some sane defaults for the code style checker flake8 100 | exclude = 101 | .tox 102 | build 103 | dist 104 | .eggs 105 | docs/conf.py 106 | 107 | [pyscaffold] 108 | # PyScaffold's parameters when the project was created. 109 | # This will be used when updating. Do not change! 110 | version = 3.2.3 111 | package = pyannote-audio 112 | extensions = 113 | markdown 114 | no_skeleton 115 | pre_commit 116 | dsproject 117 | -------------------------------------------------------------------------------- /pyannote-audio/tests/utils/test_powerset.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2023- CNRS 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
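# These tests exercise pyannote.audio.utils.powerset.Powerset, which maps
# between multi-label activity (one column per speaker) and a powerset
# encoding with one class per subset of at most max_set_size speakers,
# empty set included. For instance, num_classes=3 with max_set_size=2
# gives 1 + 3 + 3 = 7 powerset classes.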
22 | 23 | 24 | import torch 25 | 26 | from pyannote.audio.utils.powerset import Powerset 27 | 28 | 29 | def test_roundtrip(): 30 | for num_classes in range(2, 5): 31 | for max_set_size in range(1, num_classes + 1): 32 | powerset = Powerset(num_classes, max_set_size) 33 | 34 | # simulate a sequence where each frame is assigned to a different powerset class 35 | one_sequence = [ 36 | [0] * powerset.num_powerset_classes 37 | for _ in range(powerset.num_powerset_classes) 38 | ] 39 | for i in range(powerset.num_powerset_classes): 40 | one_sequence[i][i] = 1.0 41 | 42 | # make a batch out of this sequence and the same sequence in reverse order 43 | batch_powerset = torch.tensor([one_sequence, one_sequence[::-1]]) 44 | 45 | # convert from powerset to multi-label 46 | batch_multilabel = powerset.to_multilabel(batch_powerset) 47 | 48 | # convert batch back to powerset 49 | reconstruction = powerset.to_powerset(batch_multilabel) 50 | 51 | assert torch.equal(batch_powerset, reconstruction) 52 | 53 | 54 | def test_permutate_powerset(): 55 | for num_classes in range(1, 6): 56 | for max_set_size in range(1, num_classes + 1): 57 | powerset = Powerset(num_classes, max_set_size) 58 | 59 | # create (num_powerset_class, num_powerset_class)-shaped tensor, where each frame is assigned to a different powerset class 60 | # and convert it to its multi-label equivalent 61 | t1 = torch.nn.functional.one_hot( 62 | torch.arange(powerset.num_powerset_classes), 63 | powerset.num_powerset_classes, 64 | ) 65 | t1_ml = powerset.to_multilabel(t1) 66 | 67 | # then permutate the powerset class in powerset space AND the multilabel equivalent in its native space 68 | # and check it has the same result. 69 | # perm = torch.randperm(num_classes) 70 | perm = tuple(torch.randperm(num_classes).tolist()) 71 | t1_ml_perm = t1_ml[:, perm] 72 | perm_ps = powerset.permutation_mapping[perm] 73 | t1_ps_perm = t1[..., perm_ps] 74 | t1_ps_perm_ml = powerset.to_multilabel(t1_ps_perm) 75 | 76 | assert t1_ml_perm.equal(t1_ps_perm_ml) 77 | -------------------------------------------------------------------------------- /recipes/diar_ssl/data/AMI_AliMeeting_AISHELL4/dev/wav.scp: -------------------------------------------------------------------------------- 1 | ES2011a /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/ES2011a.wav 2 | ES2011b /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/ES2011b.wav 3 | ES2011c /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/ES2011c.wav 4 | ES2011d /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/ES2011d.wav 5 | IB4001 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/IB4001.wav 6 | IB4002 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/IB4002.wav 7 | IB4003 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/IB4003.wav 8 | IB4004 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/IB4004.wav 9 | IB4010 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/IB4010.wav 10 | IB4011 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/IB4011.wav 11 | IS1008a /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/IS1008a.wav 12 | IS1008b /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/IS1008b.wav 13 | IS1008c /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/IS1008c.wav 14 | IS1008d /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/IS1008d.wav 15 | TS3004a /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/TS3004a.wav 16 | TS3004b /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/TS3004b.wav 17 | TS3004c /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/TS3004c.wav 18 | TS3004d /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/TS3004d.wav 19 | R8001_M8004_MS801 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/R8001_M8004_MS801.wav 20 
| R8003_M8001_MS801 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/R8003_M8001_MS801.wav 21 | R8007_M8010_MS803 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/R8007_M8010_MS803.wav 22 | R8007_M8011_MS806 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/R8007_M8011_MS806.wav 23 | R8008_M8013_MS807 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/R8008_M8013_MS807.wav 24 | R8009_M8018_MS809 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/R8009_M8018_MS809.wav 25 | R8009_M8019_MS810 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/R8009_M8019_MS810.wav 26 | R8009_M8020_MS810 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/R8009_M8020_MS810.wav 27 | 20200706_L_R001S01C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/20200706_L_R001S01C01.wav 28 | 20200708_L_R002S07C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/20200708_L_R002S07C01.wav 29 | 20200709_L_R002S08C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/20200709_L_R002S08C01.wav 30 | 20200616_M_R001S01C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/20200616_M_R001S01C01.wav 31 | 20200715_M_R002S06C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/20200715_M_R002S06C01.wav 32 | 20200620_M_R002S08C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/20200620_M_R002S08C01.wav 33 | 20200620_M_R002S07C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/20200620_M_R002S07C01.wav 34 | 20200710_M_R002S03C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/20200710_M_R002S03C01.wav 35 | 20200710_M_R002S07C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/20200710_M_R002S07C01.wav 36 | 20200712_M_R002S05C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/20200712_M_R002S05C01.wav 37 | 20200704_M_R002S06C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/20200704_M_R002S06C01.wav 38 | 20200622_M_R002S02C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/20200622_M_R002S02C01.wav 39 | 20200715_M_R002S03C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/20200715_M_R002S03C01.wav 40 | 20200623_S_R001S06C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/20200623_S_R001S06C01.wav 41 | 20200805_S_R001S01C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/20200805_S_R001S01C01.wav 42 | 20200701_S_R001S03C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/20200701_S_R001S03C01.wav 43 | 20200805_S_R001S03C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/20200805_S_R001S03C01.wav 44 | 20200702_S_R001S02C01 /YOUR_PATH/AMI_AliMeeting_AISHELL4/wavs/dev/20200702_S_R001S02C01.wav 45 | -------------------------------------------------------------------------------- /recipes/diar_ssl/data/NSF/dev_sc/all.uem: -------------------------------------------------------------------------------- 1 | S30500101 1 0.000 380.935 2 | S30500201 1 0.000 380.935 3 | S30500211 1 0.000 380.935 4 | S30502101 1 0.000 536.593 5 | S30502201 1 0.000 536.593 6 | S30502211 1 0.000 536.593 7 | S30503101 1 0.000 396.243 8 | S30503201 1 0.000 396.243 9 | S30503211 1 0.000 396.243 10 | S30504101 1 0.000 410.198 11 | S30504201 1 0.000 410.198 12 | S30504211 1 0.000 410.198 13 | S30505101 1 0.000 341.835 14 | S30505201 1 0.000 341.835 15 | S30505211 1 0.000 341.835 16 | S30508101 1 0.000 357.795 17 | S30508201 1 0.000 357.795 18 | S30508211 1 0.000 357.795 19 | S30509101 1 0.000 371.505 20 | S30509201 1 0.000 371.505 21 | S30509211 1 0.000 371.505 22 | S30519101 1 0.000 291.910 23 | S30519201 1 0.000 291.910 24 | S30519211 1 0.000 291.910 25 | S30520101 1 0.000 392.993 26 | S30520201 1 0.000 392.993 27 | S30520211 1 0.000 392.993 28 | S30522101 1 0.000 358.823 29 | S30522201 1 0.000 358.823 30 | S30522211 1 0.000 358.823 31 | S30552101 1 0.000 318.121 32 | S30552201 1 0.000 
318.121 33 | S30552211 1 0.000 318.121 34 | S30552301 1 0.000 318.121 35 | S30560101 1 0.000 298.760 36 | S30560201 1 0.000 298.760 37 | S30560211 1 0.000 298.760 38 | S30560301 1 0.000 298.760 39 | S30563101 1 0.000 376.491 40 | S30563201 1 0.000 376.491 41 | S30563211 1 0.000 376.491 42 | S30563301 1 0.000 376.491 43 | S30569101 1 0.000 320.840 44 | S30569201 1 0.000 320.840 45 | S30569211 1 0.000 320.840 46 | S30569301 1 0.000 320.840 47 | S30570101 1 0.000 358.233 48 | S30570201 1 0.000 358.233 49 | S30570211 1 0.000 358.233 50 | S30570301 1 0.000 358.233 51 | S30601101 1 0.000 362.679 52 | S30601201 1 0.000 362.679 53 | S30601211 1 0.000 362.679 54 | S30601301 1 0.000 362.679 55 | S30603101 1 0.000 414.218 56 | S30603201 1 0.000 414.218 57 | S30603211 1 0.000 414.218 58 | S30603301 1 0.000 414.218 59 | S30605101 1 0.000 348.659 60 | S30605201 1 0.000 348.659 61 | S30605211 1 0.000 348.659 62 | S30605301 1 0.000 348.659 63 | S30609101 1 0.000 367.714 64 | S30609201 1 0.000 367.714 65 | S30609211 1 0.000 367.714 66 | S30610101 1 0.000 492.053 67 | S30610201 1 0.000 492.053 68 | S30610211 1 0.000 492.053 69 | S30633101 1 0.000 373.159 70 | S30633201 1 0.000 373.159 71 | S30633211 1 0.000 373.159 72 | S30633301 1 0.000 373.159 73 | S30636101 1 0.000 404.768 74 | S30636201 1 0.000 404.768 75 | S30636211 1 0.000 404.768 76 | S30636301 1 0.000 404.768 77 | S30639101 1 0.000 362.749 78 | S30639201 1 0.000 362.749 79 | S30639211 1 0.000 362.749 80 | S30639301 1 0.000 362.749 81 | S30716101 1 0.000 251.385 82 | S30716201 1 0.000 251.385 83 | S30716211 1 0.000 251.385 84 | S30716301 1 0.000 251.385 85 | S30718101 1 0.000 356.323 86 | S30718201 1 0.000 356.323 87 | S30718211 1 0.000 356.323 88 | S30718301 1 0.000 356.323 89 | S30719101 1 0.000 406.550 90 | S30719201 1 0.000 406.550 91 | S30719211 1 0.000 406.550 92 | S30719301 1 0.000 406.550 93 | S30721201 1 0.000 367.565 94 | S30721211 1 0.000 367.565 95 | S30721301 1 0.000 367.565 96 | S30724201 1 0.000 378.140 97 | S30724211 1 0.000 378.140 98 | S30724301 1 0.000 378.140 99 | S30725201 1 0.000 391.870 100 | S30725211 1 0.000 391.870 101 | S30725301 1 0.000 391.870 102 | S30752101 1 0.000 385.781 103 | S30752201 1 0.000 385.781 104 | S30752211 1 0.000 385.781 105 | S30752301 1 0.000 385.781 106 | S30754101 1 0.000 444.688 107 | S30754201 1 0.000 444.688 108 | S30754211 1 0.000 444.688 109 | S30754301 1 0.000 444.688 110 | S30773101 1 0.000 420.119 111 | S30773201 1 0.000 420.119 112 | S30773211 1 0.000 420.119 113 | S30773301 1 0.000 420.119 114 | S30774101 1 0.000 372.440 115 | S30774201 1 0.000 372.440 116 | S30774211 1 0.000 372.440 117 | S30774301 1 0.000 372.440 118 | -------------------------------------------------------------------------------- /pyannote-audio/.github/ISSUE_TEMPLATE/bug_report.yml: -------------------------------------------------------------------------------- 1 | name: Bug report 2 | description: Report a bug in pyannote.audio 3 | body: 4 | 5 | - type: markdown 6 | attributes: 7 | value: | 8 | When reporting bugs, please follow the guidelines in this template. This helps identify the problem precisely and thus enables contributors to fix it faster. 9 | - Write a descriptive issue title above. 10 | - The golden rule is to **always open *one* issue for *one* bug**. If you notice several bugs and want to report them, make sure to create one new issue for each of them. 
11 | - Search [open](https://github.com/pyannote/pyannote-audio/issues) and [closed](https://github.com/pyannote/pyannote-audio/issues?q=is%3Aissue+is%3Aclosed) issues to ensure it has not already been reported. If you don't find a relevant match or if you're unsure, don't hesitate to **open a new issue**. The bugsquad will handle it from there if it's a duplicate. 12 | - Please always check if your issue is reproducible in the latest version – it may already have been fixed! 13 | - If you use a custom build, please test if your issue is reproducible in official releases too. 14 | 15 | - type: textarea 16 | attributes: 17 | label: Tested versions 18 | description: | 19 | To properly fix a bug, we need to identify if the bug was recently introduced in the engine, or if it was always present. 20 | - Please specify the pyannote.audio version you found the issue in, including the **Git commit hash** if using a development build. 21 | - If you can, **please test earlier pyannote.audio versions** and, if applicable, newer versions (development branch). Mention whether the bug is reproducible or not in the versions you tested. 22 | - The aim is for us to identify whether a bug is a **regression**, i.e. an issue that didn't exist in a previous version, but was introduced later on, breaking existing functionality. For example, if a bug is reproducible in 3.2 but not in 3.0, we would like you to test intermediate 3.1 to find which version is the first one where the issue can be reproduced. 23 | placeholder: | 24 | - Reproducible in: 3.1, 3.2, and later 25 | - Not reproducible in: 3.0 26 | validations: 27 | required: true 28 | 29 | - type: input 30 | attributes: 31 | label: System information 32 | description: | 33 | - Specify the OS version, and when relevant hardware information. 34 | - For issues that are likely OS-specific and/or GPU-related, please specify the GPU model and architecture. 35 | - **Bug reports not including the required information may be closed at the maintainers' discretion.** If in doubt, always include all the requested information; it's better to include too much information than not enough information. 36 | placeholder: macOS 13.6 - pyannote.audio 3.1.1 - M1 Pro 37 | validations: 38 | required: true 39 | 40 | - type: textarea 41 | attributes: 42 | label: Issue description 43 | description: | 44 | Describe your issue briefly. What doesn't work, and how do you expect it to work instead? 45 | You can include audio, images or videos with drag and drop, and format code blocks or logs with ``` tags. 46 | validations: 47 | required: true 48 | 49 | - type: input 50 | attributes: 51 | label: Minimal reproduction example (MRE) 52 | description: | 53 | Having reproducible issues is a prerequisite for contributors to be able to solve them. 54 | Include a link to minimal reproduction example using [this Google Colab notebook](https://colab.research.google.com/github/pyannote/pyannote-audio/blob/develop/tutorials/MRE_template.ipynb) as a starting point. 
55 | validations: 56 | required: true 57 | -------------------------------------------------------------------------------- /pyannote-audio/FAQ.md: -------------------------------------------------------------------------------- 1 | 2 | # Frequently Asked Questions 3 | - [Can I apply pretrained pipelines on audio already loaded in memory?](#can-i-apply-pretrained-pipelines-on-audio-already-loaded-in-memory) 4 | - [Can I use gated models (and pipelines) offline?](#can-i-use-gated-models-(and-pipelines)-offline) 5 | - [Does pyannote support streaming speaker diarization?](#does-pyannote-support-streaming-speaker-diarization) 6 | - [How can I improve performance?](#how-can-i-improve-performance) 7 | - [How does one spell and pronounce pyannote.audio?](#how-does-one-spell-and-pronounce-pyannoteaudio) 8 | 9 | 10 | ## Can I apply pretrained pipelines on audio already loaded in memory? 11 | 12 | Yes: read [this tutorial](tutorials/applying_a_pipeline.ipynb) until the end. 13 | 14 | 15 | ## Can I use gated models (and pipelines) offline? 16 | 17 | **Short answer**: yes, see [this tutorial](tutorials/applying_a_model.ipynb) for models and [that one](tutorials/applying_a_pipeline.ipynb) for pipelines. 18 | 19 | **Long answer**: gating models and pipelines allows [me](https://herve.niderb.fr) to know a bit more about `pyannote.audio` user base and eventually help me write grant proposals to make `pyannote.audio` even better. So, please fill gating forms as precisely as possible. 20 | 21 | For instance, before gating `pyannote/speaker-diarization`, I had no idea that so many people were relying on it in production. Hint: sponsors are more than welcome! Maintaining open source libraries is time consuming. 22 | 23 | That being said, this whole authentication process does not prevent you from using official `pyannote.audio` models offline (i.e. without going through the authentication process in every `docker run ...` or whatever you are using in production): see [this tutorial](tutorials/applying_a_model.ipynb) for models and [that one](tutorials/applying_a_pipeline.ipynb) for pipelines. 24 | 25 | 26 | ## Does pyannote support streaming speaker diarization? 27 | 28 | **Short answer:** not out of the box, no. 29 | 30 | **Long answer:** [I](https://herve.niderb.fr) am looking for sponsors to add this feature. In the meantime, [`diart`](https://github.com/juanmc2005/StreamingSpeakerDiarization) is the closest you can get from a streaming `pyannote.audio`. You might also be interested in [this blog post](https://herve.niderb.fr/fastpages/2021/08/05/Streaming-voice-activity-detection-with-pyannote.html) about streaming voice activity detection based on `pyannote.audio`. 31 | 32 | 33 | ## How can I improve performance? 34 | 35 | **Long answer:** 36 | 37 | 1. Manually annotate dozens of conversations as precisely as possible. 38 | 2. Separate them into train (80%), development (10%) and test (10%) subsets. 39 | 3. Setup the data for use with [`pyannote.database`](https://github.com/pyannote/pyannote-database#speaker-diarization). 40 | 4. Follow [this recipe](https://github.com/pyannote/pyannote-audio/blob/develop/tutorials/adapting_pretrained_pipeline.ipynb). 41 | 5. Enjoy. 42 | 43 | **Also:** [I am available](https://herve.niderb.fr) for contracting to help you with that. 44 | 45 | 46 | ## How does one spell and pronounce pyannote.audio? 47 | 48 | 📝 Written in lower case: `pyannote.audio` (or `pyannote` if you are lazy). Not `PyAnnote` nor `PyAnnotate` (sic). 49 | 📢 Pronounced like the french verb `pianoter`. 
`pi` like in `pi`ano, not `py` like in `py`thon. 50 | 🎹 `pianoter` means to play the piano (hence the logo 🤯). 51 | 52 |
53 | 54 | Generated by [FAQtory](https://github.com/willmcgugan/faqtory) 55 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/utils/reproducibility.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2023- CNRS 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | # Context: https://github.com/pyannote/pyannote-audio/issues/1370 24 | 25 | import warnings 26 | 27 | import torch 28 | 29 | 30 | class ReproducibilityError(Exception): 31 | ... 32 | 33 | 34 | class ReproducibilityWarning(UserWarning): 35 | ... 36 | 37 | 38 | def raise_reproducibility(device: torch.device): 39 | if (device.type == "cuda") and ( 40 | torch.backends.cuda.matmul.allow_tf32 or torch.backends.cudnn.allow_tf32 41 | ): 42 | raise ReproducibilityError( 43 | "Please disable TensorFloat-32 (TF32) by calling\n" 44 | " >>> import torch\n" 45 | " >>> torch.backends.cuda.matmul.allow_tf32 = False\n" 46 | " >>> torch.backends.cudnn.allow_tf32 = False\n" 47 | "or you might face reproducibility issues and obtain lower accuracy.\n" 48 | "See https://github.com/pyannote/pyannote-audio/issues/1370 for more details." 49 | ) 50 | 51 | 52 | def warn_reproducibility(device: torch.device): 53 | if (device.type == "cuda") and ( 54 | torch.backends.cuda.matmul.allow_tf32 or torch.backends.cudnn.allow_tf32 55 | ): 56 | warnings.warn( 57 | ReproducibilityWarning( 58 | "Please disable TensorFloat-32 (TF32) by calling\n" 59 | " >>> import torch\n" 60 | " >>> torch.backends.cuda.matmul.allow_tf32 = False\n" 61 | " >>> torch.backends.cudnn.allow_tf32 = False\n" 62 | "or you might face reproducibility issues and obtain lower accuracy.\n" 63 | "See https://github.com/pyannote/pyannote-audio/issues/1370 for more details." 
64 | ) 65 | ) 66 | 67 | 68 | def fix_reproducibility(device: torch.device): 69 | if (device.type == "cuda") and ( 70 | torch.backends.cuda.matmul.allow_tf32 or torch.backends.cudnn.allow_tf32 71 | ): 72 | torch.backends.cuda.matmul.allow_tf32 = False 73 | torch.backends.cudnn.allow_tf32 = False 74 | warnings.warn( 75 | ReproducibilityWarning( 76 | "TensorFloat-32 (TF32) has been disabled as it might lead to reproducibility issues and lower accuracy.\n" 77 | "It can be re-enabled by calling\n" 78 | " >>> import torch\n" 79 | " >>> torch.backends.cuda.matmul.allow_tf32 = True\n" 80 | " >>> torch.backends.cudnn.allow_tf32 = True\n" 81 | "See https://github.com/pyannote/pyannote-audio/issues/1370 for more details.\n" 82 | ) 83 | ) 84 | -------------------------------------------------------------------------------- /recipes/diar_ssl/run_stage.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Licensed under the MIT license. 4 | # Copyright 2024 Brno University of Technology (author: Jiangyu Han, ihan@fit.vut.cz) 5 | 6 | set -eu 7 | ulimit -n 2048 8 | 9 | # general setup 10 | stage=1 11 | recipe_root=/YOUR_PATH/DiariZen/recipes/diar_ssl 12 | exp_root=$recipe_root/exp 13 | conf_dir=$recipe_root/conf 14 | 15 | # training setup 16 | use_dual_opt=true # true for wavlm_updated_conformer.toml; false for the others 17 | train_conf=$conf_dir/wavlm_updated_conformer.toml 18 | # train_conf=$conf_dir/wavlm_frozen_conformer.toml 19 | # train_conf=$conf_dir/fbank_conformer.toml 20 | # train_conf=$conf_dir/pyannote_baseline.toml 21 | 22 | conf_name=`ls $train_conf | awk -F '/' '{print $NF}' | awk -F '.' '{print $1}'` 23 | 24 | # inference setup 25 | dtype=test 26 | data_dir=$recipe_root/data/AMI_AliMeeting_AISHELL4 27 | seg_duration=8 28 | 29 | # clustering setup 30 | clustering_method=AgglomerativeClustering 31 | ahc_threshold=0.70 32 | min_cluster_size=30 33 | infer_affix=_constrained_AHC_thres_${ahc_threshold}_mcs_${min_cluster_size} 34 | 35 | avg_ckpt_num=5 36 | val_metric=Loss # Loss or DER 37 | val_mode=best # [prev, best, center] 38 | 39 | # scoring setup 40 | collar=0 41 | REF_DIR=$data_dir 42 | dscore_dir=/YOUR_PATH/DiariZen/dscore 43 | 44 | # ======================================= 45 | # ======================================= 46 | if [ $stage -le 1 ]; then 47 | if (! $use_dual_opt); then 48 | echo "stage1: use single-opt for model training..." 49 | conda activate diarizen && CUDA_VISIBLE_DEVICES="0,1" accelerate launch \ 50 | --num_processes 2 --main_process_port 1134 \ 51 | run_single_opt.py -C $train_conf -M validate 52 | else 53 | echo "stage1: use dual-opt for model training..." 54 | conda activate diarizen && CUDA_VISIBLE_DEVICES="0,1,2,3" accelerate launch \ 55 | --num_processes 4 --main_process_port 1134 \ 56 | run_dual_opt.py -C $train_conf -M train 57 | fi 58 | fi 59 | 60 | diarization_dir=$exp_root/$conf_name # can be replaced by our pre-trained models, e.g. diarization_dir=/YOUR_PATH/checkpoints/wavlm_updated_conformer 61 | config_dir=`ls $diarization_dir/*.toml | sort -r | head -n 1` 62 | embedding_model=/YOUR_PATH/pretrained/pyannote3/wespeaker-voxceleb-resnet34-LM/pytorch_model.bin # it's necessary to have "pyannote" in your directory path 63 | 64 | if [ $stage -le 2 ]; then 65 | echo "stage2: model inference..." 
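# Run inference on a single GPU; checkpoint selection below relies on the
# validation Loss/DER lines grep'ed out of the largest training log.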
66 | export CUDA_VISIBLE_DEVICES=0 67 | 68 | train_log=`du -h $diarization_dir/*.log | sort -rh | head -n 1 | awk '{print $NF}'` 69 | cat $train_log | grep 'Loss/DER' | awk -F ']:' '{print $NF}' > $diarization_dir/val_metric_summary.lst 70 | 71 | for dset in AMI AliMeeting AISHELL4; do 72 | conda activate diarizen && python infer_avg.py -C $config_dir \ 73 | -i ${data_dir}/${dtype}/${dset}/wav.scp \ 74 | -o ${diarization_dir}/infer$infer_affix/metric_${val_metric}_${val_mode}/avg_ckpt${avg_ckpt_num}/${dtype}/${dset} \ 75 | --embedding_model $embedding_model \ 76 | --avg_ckpt_num $avg_ckpt_num \ 77 | --val_metric $val_metric \ 78 | --val_mode $val_mode \ 79 | --val_metric_summary $diarization_dir/val_metric_summary.lst \ 80 | --seg_duration $seg_duration \ 81 | --clustering_method $clustering_method \ 82 | --ahc_threshold $ahc_threshold \ 83 | --min_cluster_size $min_cluster_size 84 | 85 | echo "stage3: scoring..." 86 | SYS_DIR=${diarization_dir}/infer$infer_affix/metric_${val_metric}_${val_mode}/avg_ckpt${avg_ckpt_num} 87 | OUT_DIR=${SYS_DIR}/${dtype}/${dset} 88 | conda activate diarizen && python ${dscore_dir}/score.py \ 89 | -r ${REF_DIR}/${dtype}/${dset}/rttm \ 90 | -s $OUT_DIR/*.rttm --collar ${collar} \ 91 | > $OUT_DIR/result_collar${collar} 92 | done 93 | fi 94 | -------------------------------------------------------------------------------- /pyannote-audio/pyannote/audio/cli/evaluate.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2022- CNRS 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
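# Usage sketch (installed as the `pyannote-audio-eval` console script via
# setup.cfg; defaults live in evaluate_config/config.yaml, so the overrides
# below are illustrative only):
#
#   pyannote-audio-eval model=pyannote/ci-segmentation \
#       protocol=Debug.SpeakerDiarization.Debug subset=test
#
# Hydra parses the key=value overrides; the script scores the chosen subset
# and writes the per-file report to report.txt in the Hydra run directory.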

from typing import Optional

import hydra
from omegaconf import DictConfig
from pyannote.database import FileFinder, ProtocolFile, registry
from rich.progress import Progress

from pyannote.audio import Inference, Model
from pyannote.audio.pipelines.utils import get_devices
from pyannote.audio.utils.metric import DiscreteDiarizationErrorRate
from pyannote.audio.utils.signal import binarize


@hydra.main(config_path="evaluate_config", config_name="config")
def evaluate(cfg: DictConfig) -> Optional[float]:

    # load pretrained model
    (device,) = get_devices(needs=1)
    model = Model.from_pretrained(cfg.model, device=device)

    # load databases into the registry if specified
    if "registry" in cfg:
        for database_yml in cfg.registry.split(","):
            registry.load_database(database_yml)

    # load evaluation files
    protocol = registry.get_protocol(
        cfg.protocol, preprocessors={"audio": FileFinder()}
    )

    files = list(getattr(protocol, cfg.subset)())

    # load evaluation metric
    metric = DiscreteDiarizationErrorRate()

    with Progress() as progress:

        main_task = progress.add_task(protocol.name, total=len(files))
        file_task = progress.add_task("Processing", total=1.0)

        def progress_hook(completed: Optional[int] = None, total: Optional[int] = None):
            # guard against the default None values to avoid a TypeError
            if completed is None or total is None:
                return
            progress.update(file_task, completed=completed / total)

        inference = Inference(model, device=device)
        warm_up = cfg.warm_up / inference.duration

        def hypothesis(file: ProtocolFile):
            return Inference.trim(
                binarize(inference(file, hook=progress_hook)),
                warm_up=(warm_up, warm_up),
            )

        for file in files:
            progress.update(file_task, description=file["uri"])
            reference = file["annotation"]
            uem = file["annotated"]
            _ = metric(reference, hypothesis(file), uem=uem)
            progress.advance(main_task)

    report = metric.report(display=False)

    with open("report.txt", "w") as f:
        f.write(f"# Model: {cfg.model}\n")
        f.write(f"# Protocol: {protocol.name}\n")
        f.write(f"# Subset: {cfg.subset}\n")
        f.write("\n")
        report = report.to_string(
            index=True,
            sparsify=False,
            justify="right",
            float_format=lambda x: "{0:.2f}".format(x),
        )
        f.write(f"{report}")


if __name__ == "__main__":
    evaluate()
--------------------------------------------------------------------------------
/diarizen/ckpt_utils.py:
--------------------------------------------------------------------------------
# Licensed under the MIT license.
# Copyright 2022 Brno University of Technology (author: Federico Landini, landini@fit.vut.cz)
# Copyright 2024 Brno University of Technology (author: Jiangyu Han, ihan@fit.vut.cz)

import copy
import os
from pathlib import Path
from typing import Dict, List

import torch
import torch.nn as nn


def average_checkpoints(
    model: nn.Module,
    checkpoint_list: List[Dict],
) -> nn.Module:
    """Average the checkpoints in `checkpoint_list` into a copy of `model`."""
    states_dict_list = []
    for ckpt_data in checkpoint_list:
        ckpt_path = ckpt_data['bin_path']
        copy_model = copy.deepcopy(model)
        checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
        copy_model.load_state_dict(checkpoint)
        states_dict_list.append(copy_model.state_dict())
    avg_state_dict = average_states(states_dict_list, torch.device('cpu'))
    avg_model = copy.deepcopy(model)
    avg_model.load_state_dict(avg_state_dict)
    return avg_model

def average_states(
    states_list: List[Dict[str, torch.Tensor]],
    device: torch.device,
) -> Dict[str, torch.Tensor]:
    """Element-wise average of a list of state dicts."""
    qty = len(states_list)
    # copy the first state dict so the caller's list is not mutated in place
    avg_state = copy.deepcopy(states_list[0])
    for i in range(1, qty):
        for key in avg_state:
            avg_state[key] += states_list[i][key].to(device)
    for key in avg_state:
        avg_state[key] = avg_state[key] / qty
    return avg_state

def load_metric_summary(metric_file, ckpt_path):
    """Parse 'Validation Loss/DER' log lines into {epoch, bin_path, Loss, DER} dicts."""
    with open(metric_file, "r") as f:
        lines = f.readlines()
    out_lst = []
    for line in lines:
        assert "Validation Loss/DER" in line
        epoch = line.split()[4].split(':')[0]
        Loss, DER = line.split()[-3], line.split()[-1]
        bin_path = f"epoch_{str(epoch).zfill(4)}/pytorch_model.bin"
        out_lst.append({
            'epoch': int(epoch),
            'bin_path': ckpt_path / bin_path,
            'Loss': float(Loss),
            'DER': float(DER)
        })
    return out_lst

def average_ckpt(ckpt_dir, model, wavlm_only=False, val_metric='Loss', avg_ckpt_num=5, val_mode="best"):
    # case 1: ckpt_dir is a single checkpoint file
    if os.path.isfile(ckpt_dir):
        ckpt_loaded = torch.load(ckpt_dir, map_location=torch.device('cpu'))
        if not wavlm_only:
            print(f"No model averaging | Fine-tune model from: {ckpt_dir}")
            model.load_state_dict(ckpt_loaded, strict=True)
        else:
            print(f"Only initialize wavlm model from: {ckpt_dir}")
            wavlm_prefix = "wavlm_model."
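            # keep only the parameters stored under the "wavlm_model." prefix
            # and strip that prefix so the keys match the WavLM sub-module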
            wavlm_state_dict = {
                k[len(wavlm_prefix):]: v
                for k, v in ckpt_loaded.items()
                if k.startswith(wavlm_prefix)
            }
            model.wavlm_model.load_state_dict(wavlm_state_dict, strict=False)
        return model

    # case 2: ckpt_dir points to a specific epoch directory
    if 'checkpoints/epoch_' in ckpt_dir:
        print(f"No model averaging | Fine-tune model from certain epoch: {ckpt_dir.split('/')[-1]}")
        ckpt_loaded = torch.load(os.path.join(ckpt_dir, 'pytorch_model.bin'), map_location=torch.device('cpu'))
        model.load_state_dict(ckpt_loaded)
        return model

    # case 3: average the best avg_ckpt_num checkpoints up to the best epoch
    assert val_metric == "Loss" and val_mode == "best"
    print(f'averaging best {avg_ckpt_num} checkpoints to the converged moment...')

    ckpt_dir = Path(ckpt_dir).expanduser().absolute()
    ckpt_path = ckpt_dir / 'checkpoints'
    val_metric_path = ckpt_dir / 'val_metric_summary.lst'

    val_metric_lst = load_metric_summary(val_metric_path, ckpt_path)
    val_metric_lst_sorted = sorted(val_metric_lst, key=lambda i: i[val_metric])
    best_val_metric_idx = val_metric_lst.index(val_metric_lst_sorted[0])
    # clamp the window start at 0 so an early best epoch cannot produce a
    # negative slice index (which would silently wrap around the list)
    start_idx = max(0, best_val_metric_idx - avg_ckpt_num + 1)
    val_metric_lst_out = val_metric_lst[start_idx : best_val_metric_idx + 1]

    return average_checkpoints(model, val_metric_lst_out)
--------------------------------------------------------------------------------