├── pyproject.toml ├── src └── diarizers │ ├── models │ ├── __init__.py │ ├── pyannet.py │ └── model.py │ ├── data │ ├── __init__.py │ ├── preprocess.py │ └── speaker_diarization.py │ ├── __init__.py │ ├── dependency_versions_table.py │ ├── utils.py │ └── test.py ├── .gitignore ├── Makefile ├── sanity_checks └── check_preprocessing.py ├── datasets ├── README.md └── spd_datasets.py ├── test_segmentation.py ├── train_segmentation.py ├── utils └── check_dummies.py ├── README.md └── setup.py /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 119 3 | target-version = ['py37'] -------------------------------------------------------------------------------- /src/diarizers/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import SegmentationModel, SegmentationModelConfig 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__/ 2 | env 3 | .vscode 4 | 5 | analysis 6 | synthetic 7 | dist 8 | build 9 | checkpoint 10 | -------------------------------------------------------------------------------- /src/diarizers/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .preprocess import Preprocess 2 | from .speaker_diarization import SpeakerDiarizationDataset 3 | -------------------------------------------------------------------------------- /src/diarizers/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1" 2 | 3 | from .data import Preprocess, SpeakerDiarizationDataset 4 | from .models import SegmentationModel, SegmentationModelConfig 5 | from .test import Test 6 | from .utils import DataCollator, Metrics, train_val_test_split 7 | -------------------------------------------------------------------------------- /src/diarizers/dependency_versions_table.py: -------------------------------------------------------------------------------- 1 | # THIS FILE HAS BEEN AUTOGENERATED. To update: 2 | # 1. modify the `_deps` dict in setup.py 3 | # 2. run `make deps_table_update`` 4 | deps = { 5 | "accelerate": "accelerate>=0.14.0", 6 | "torch": "torch>=1.9", 7 | "transformers": "transformers>=4.24.0", 8 | "black": "black==22.8", 9 | "isort": "isort>=5.5.4", 10 | "flake8": "flake8>=3.8.3", 11 | "numpy": "numpy", 12 | "filelock": "filelock", 13 | "importlib_metadata": "importlib_metadata", 14 | "datasets": "datasets", 15 | "pyannote.audio": "pyannote.audio", 16 | } 17 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: deps_table_update modified_only_fixup extra_style_checks quality style fixup 2 | 3 | # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!) 
4 | export PYTHONPATH = src 5 | 6 | check_dirs := src 7 | 8 | # Update src/diarizers/dependency_versions_table.py 9 | 10 | deps_table_update: 11 | @python setup.py deps_table_update 12 | 13 | deps_table_check_updated: 14 | @md5sum src/diarizers/dependency_versions_table.py > md5sum.saved 15 | @python setup.py deps_table_update 16 | @md5sum -c --quiet md5sum.saved || (printf "\nError: the version dependency table is outdated.\nPlease run 'make fixup' or 'make style' and commit the changes.\n\n" && exit 1) 17 | @rm md5sum.saved 18 | 19 | # autogenerating code 20 | 21 | autogenerate_code: deps_table_update 22 | 23 | # Check that the repo is in a good state 24 | 25 | # this target runs checks on all files 26 | 27 | quality: 28 | black --check --preview $(check_dirs) 29 | isort --check-only $(check_dirs) 30 | flake8 $(check_dirs) 31 | 32 | # Format source code automatically and check is there are any problems left that need manual fixing 33 | 34 | # this target runs checks on all files and potentially modifies some of them 35 | 36 | style: 37 | black --preview $(check_dirs) 38 | isort $(check_dirs) 39 | ${MAKE} autogenerate_code 40 | 41 | fix-copies: 42 | python utils/check_dummies.py --fix_and_overwrite -------------------------------------------------------------------------------- /sanity_checks/check_preprocessing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pyannote.audio import Model 3 | from pyannote.audio.tasks import SpeakerDiarization 4 | from pyannote.database import registry 5 | from sklearn.metrics.pairwise import cosine_similarity 6 | 7 | from datasets import load_dataset 8 | 9 | 10 | def get_chunk_from_pyannote(seg_task, file_id, start_time, duration): 11 | 12 | seg_task.prepare_data() 13 | seg_task.setup() 14 | 15 | chunk = seg_task.prepare_chunk(file_id, start_time, duration) 16 | 17 | return chunk 18 | 19 | 20 | def sanity_checks(): 21 | 22 | # Extract 10 second audio from meeting EN2001a (= file_id 124). 23 | # We choose start_time = 3.34 to match with the first 10 seconds of audio from the synthetic AMI. 24 | synthetic_ami_chunk = synthetic_ami_dataset_processed["train"][0] 25 | waveform_synthetic = np.array(synthetic_ami_chunk["waveforms"]) 26 | synthetic_labels = np.array(synthetic_ami_chunk["labels"]) 27 | index_positions = np.nonzero(waveform_synthetic) 28 | 29 | real_ami_chunk = get_chunk_from_pyannote(seg_task, 124, 3.34, 10) 30 | real_labels = real_ami_chunk["y"].data 31 | waveform_real = np.array(real_ami_chunk["X"][0]) 32 | 33 | waveform_synthetic_without_zeros = waveform_synthetic[index_positions] 34 | waveform_real_without_zeros = waveform_real[index_positions] 35 | 36 | similarity_without_zeros = cosine_similarity([waveform_synthetic_without_zeros], [waveform_real_without_zeros]) 37 | similarity_with_zeros = cosine_similarity([waveform_synthetic], [waveform_real]) 38 | 39 | assert (synthetic_labels == real_labels).all(), "labels are not matching" 40 | assert similarity_without_zeros > 0.95 41 | assert similarity_with_zeros > 0.8 42 | 43 | # We choose start_time = 5.90 to get a sample that doesn't match with the first 10 seconds of audio from the synthetic AMI. 
44 |     real_ami_chunk = get_chunk_from_pyannote(seg_task, 124, 5.90, 10)
45 |     real_labels = real_ami_chunk["y"].data
46 |     waveform_real = np.array(real_ami_chunk["X"][0])
47 |     waveform_real_without_zeros = waveform_real[index_positions]
48 | 
49 |     similarity_without_zeros = cosine_similarity([waveform_synthetic_without_zeros], [waveform_real_without_zeros])
50 |     similarity_with_zeros = cosine_similarity([waveform_synthetic], [waveform_real])
51 | 
52 |     assert not (synthetic_labels == real_labels).all(), "labels should not match for a shifted chunk"
53 |     assert similarity_without_zeros < 0.01
54 |     assert similarity_with_zeros < 0.01
55 | 
56 | 
57 | if __name__ == "__main__":
58 | 
59 |     registry.load_database("/home/kamil/datasets/AMI-diarization-setup/pyannote/database.yml")
60 |     ami = registry.get_protocol("AMI.SpeakerDiarization.only_words")
61 | 
62 |     seg_task = SpeakerDiarization(ami, duration=10.0, max_speakers_per_chunk=3, max_speakers_per_frame=2)
63 |     pretrained = Model.from_pretrained("pyannote/segmentation-3.0", use_auth_token=True)
64 |     seg_task.model = pretrained
65 | 
66 |     synthetic_ami_dataset_processed = load_dataset("kamilakesbi/real_ami_processed_sc")
67 | 
68 |     sanity_checks()
69 | 
--------------------------------------------------------------------------------
/datasets/README.md:
--------------------------------------------------------------------------------
1 | # Speaker diarization datasets
2 | 
3 | ## Add any speaker diarization dataset to the hub
4 | 
5 | General steps to add a speaker diarization dataset to the hub:
6 | 
7 | 1. Prepare a folder containing the audio and annotation files, organised like this:
8 | 
9 | ```
10 | dataset_folder
11 | ├── audio
12 | │   ├── file_1.mp3
13 | │   ├── file_2.mp3
14 | │   └── file_3.mp3
15 | ├── annotations
16 | │   ├── file_1.rttm
17 | │   ├── file_2.rttm
18 | │   └── file_3.rttm
19 | ```
20 | 
21 | 
22 | 2. Build dictionaries with the following structure:
23 | 
24 | ```
25 | annotations_files = {
26 |     "subset1": [list of annotation files in subset1],
27 |     "subset2": [list of annotation files in subset2],
28 | }
29 | 
30 | audio_files = {
31 |     "subset1": [list of audio files in subset1],
32 |     "subset2": [list of audio files in subset2],
33 | }
34 | ```
35 | 
36 | Here, each subset will correspond to a subset of the Hugging Face dataset.
37 | 
38 | 3. Use the `SpeakerDiarizationDataset` class from `diarizers` to obtain your Hugging Face dataset:
39 | 
40 | ```
41 | from diarizers import SpeakerDiarizationDataset
42 | 
43 | dataset = SpeakerDiarizationDataset(audio_files, annotations_files).construct_dataset()
44 | ```
45 | 
46 | Note: This class can currently be used with RTTM-format annotation files (plus CHA files for CALLHOME), but may need to be adapted for other formats.
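
For example (a minimal sketch, assuming the `dataset_folder` layout above, a single `"data"` subset, and a placeholder repository name — adapt the paths, subset names and repository to your own dataset), the two dictionaries can be built with `glob` and the result pushed to the hub:

```
import glob

from diarizers import SpeakerDiarizationDataset

# Audio and annotation files are assumed to pair up index by index,
# hence the sorted() calls (file_1.mp3 <-> file_1.rttm, etc.).
audio_files = {"data": sorted(glob.glob("dataset_folder/audio/*.mp3"))}
annotations_files = {"data": sorted(glob.glob("dataset_folder/annotations/*.rttm"))}

dataset = SpeakerDiarizationDataset(audio_files, annotations_files).construct_dataset()
dataset.push_to_hub("your-username/your-dataset")
```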
47 | 
48 | ## Current datasets in diarizers-community
49 | 
50 | Below are the scripts and commands we used to add the datasets present in the [diarizers-community](https://huggingface.co/diarizers-community) organisation:
51 | 
52 | #### AMI IHM and SDM:
53 | 
54 | ```
55 | git clone https://github.com/pyannote/AMI-diarization-setup.git
56 | cd AMI-diarization-setup/pyannote/
57 | sh download_ami.sh
58 | sh download_ami_sdm.sh
59 | ```
60 | 
61 | #### CALLHOME:
62 | 
63 | Download the data for each language (example here for Japanese):
64 | 
65 | ```
66 | wget https://ca.talkbank.org/data/CallHome/jpn.zip
67 | wget -r -np -nH --cut-dirs=2 -R index.html* https://media.talkbank.org/ca/CallHome/jpn/
68 | unzip jpn.zip
69 | ```
70 | 
71 | #### VOXCONVERSE:
72 | 
73 | Download the RTTM files:
74 | 
75 | ```
76 | git clone git@github.com:joonson/voxconverse.git
77 | ```
78 | 
79 | Download the audio files:
80 | 
81 | ```
82 | wget https://www.robots.ox.ac.uk/~vgg/data/voxconverse/data/voxconverse_dev_wav.zip
83 | unzip voxconverse_dev_wav.zip
84 | 
85 | wget https://www.robots.ox.ac.uk/~vgg/data/voxconverse/data/voxconverse_test_wav.zip
86 | unzip voxconverse_test_wav.zip
87 | ```
88 | 
89 | #### SIMSAMU:
90 | 
91 | The Simsamu dataset is based on this [Hugging Face dataset](https://huggingface.co/datasets/medkit/simsamu):
92 | 
93 | ```
94 | git lfs install
95 | git clone git@hf.co:datasets/medkit/simsamu
96 | ```
97 | 
98 | #### Push to hub:
99 | 
100 | We pushed each of these datasets using the `spd_datasets.py` script and a command like the following:
101 | 
102 | ```
103 | python3 -m spd_datasets \
104 |     --dataset=callhome \
105 |     --path_to_dataset=/path_to_callhome \
106 |     --push_to_hub=True \
107 |     --hub_repository=diarizers-community/callhome
108 | ```
109 | 
110 | 
--------------------------------------------------------------------------------
/test_segmentation.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | from pyannote.audio import Model
4 | from datasets import load_dataset
5 | from diarizers import SegmentationModel, Test, train_val_test_split
6 | from dataclasses import dataclass, field
7 | from transformers import HfArgumentParser
8 | from typing import Optional
9 | 
10 | @dataclass
11 | class DataTrainingArguments:
12 |     """
13 |     Arguments pertaining to what data we are going to input to our model for evaluation.
14 | 
15 |     Using `HfArgumentParser` we can turn this class
16 |     into argparse arguments to be able to specify them on
17 |     the command line.
18 |     """
19 | 
20 |     dataset_name: str = field(
21 |         default=None,
22 |         metadata={"help": "The name of the dataset to use (via the datasets library)."}
23 |     )
24 |     dataset_config_name: str = field(
25 |         default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
26 |     )
27 | 
28 |     test_split_name: str = field(
29 |         default="test", metadata={"help": "The name of the test data split to use (via the datasets library). Defaults to 'test'"}
30 |     )
31 | 
32 |     split_on_subset: str = field(
33 |         default=None,
34 |         metadata={"help": "Automatically split the specified dataset subset into train-validation-test sets. Defaults to 'None'"},
35 |     )
36 | 
37 |     preprocessing_num_workers: Optional[int] = field(
38 |         default=None,
39 |         metadata={"help": "The number of processes to use for the preprocessing."},
40 |     )
41 | 
42 | @dataclass
43 | class ModelArguments:
44 |     """
45 |     Arguments pertaining to which model/config we are going to evaluate.
46 | """ 47 | 48 | model_name_or_path: str = field( 49 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 50 | ) 51 | 52 | cache_dir: Optional[str] = field( 53 | default=None, 54 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 55 | ) 56 | 57 | 58 | if __name__ == "__main__": 59 | 60 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 61 | 62 | parser = HfArgumentParser((DataTrainingArguments, ModelArguments)) 63 | data_args, model_args = parser.parse_args_into_dataclasses() 64 | 65 | # Load the dataset: 66 | if str(data_args.dataset_config_name): 67 | dataset = load_dataset( 68 | str(data_args.dataset_name), 69 | str(data_args.dataset_config_name), 70 | num_proc=int(data_args.preprocessing_num_workers) 71 | ) 72 | else: 73 | dataset = load_dataset( 74 | str(data_args.dataset_name), 75 | str(data_args.dataset_config_name), 76 | num_proc=int(data_args.preprocessing_num_workers) 77 | ) 78 | 79 | test_split_name = data_args.test_split_name 80 | if data_args.split_on_subset: 81 | dataset = train_val_test_split(dataset[str(data_args.split_on_subset)]) 82 | test_split_name = 'test' 83 | 84 | test_dataset = dataset[data_args.test_split_name] 85 | 86 | 87 | if model_args.model_name_or_path == "pyannote/segmentation-3.0": 88 | model = Model.from_pretrained(model_args.model_name_or_path, use_auth_token=True) 89 | else: 90 | model = SegmentationModel() 91 | model = model.from_pretrained( 92 | model_args.model_name_or_path, 93 | cache_dir=model_args.cache_dir, 94 | use_auth_token=True 95 | ) 96 | model = model.to_pyannote_model() 97 | 98 | test = Test(test_dataset, model, step=2.5) 99 | metrics = test.compute_metrics() 100 | print(metrics) -------------------------------------------------------------------------------- /src/diarizers/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from pyannote.audio.torchmetrics import (DiarizationErrorRate, FalseAlarmRate, 4 | MissedDetectionRate, 5 | SpeakerConfusionRate) 6 | from pyannote.audio.utils.powerset import Powerset 7 | 8 | from datasets import DatasetDict 9 | 10 | 11 | def train_val_test_split(dataset, train_size=0.8, val_size=0.1, test_size=0.1): 12 | dataset_split = dataset.train_test_split(test_size=test_size, seed=42) 13 | train_dataset = dataset_split["train"] 14 | test_dataset = dataset_split["test"] 15 | 16 | dataset = train_dataset.train_test_split(train_size=train_size / (train_size + val_size), seed=42) 17 | train_dataset = dataset["train"] 18 | val_dataset = dataset["test"] 19 | 20 | return DatasetDict( 21 | { 22 | "train": train_dataset, 23 | "validation": val_dataset, 24 | "test": test_dataset, 25 | } 26 | ) 27 | 28 | 29 | class Metrics: 30 | def __init__(self, specifications) -> None: 31 | self.powerset = specifications.powerset 32 | self.classes = specifications.classes 33 | self.powerset_max_classes = specifications.powerset_max_classes 34 | 35 | self.model_powerset = Powerset( 36 | len(self.classes), 37 | self.powerset_max_classes, 38 | ) 39 | 40 | self.metrics = { 41 | "der": DiarizationErrorRate(0.5), 42 | "confusion": SpeakerConfusionRate(0.5), 43 | "missed_detection": MissedDetectionRate(0.5), 44 | "false_alarm": FalseAlarmRate(0.5), 45 | } 46 | 47 | def der_metric(self, eval_pred): 48 | logits, labels = eval_pred 49 | 50 | if self.powerset: 51 | predictions = self.model_powerset.to_multilabel(torch.tensor(logits)) 52 | else: 53 | predictions = 
torch.tensor(logits) 54 | 55 | labels = torch.tensor(labels) 56 | 57 | predictions = torch.transpose(predictions, 1, 2) 58 | labels = torch.transpose(labels, 1, 2) 59 | 60 | metrics = {"der": 0, "false_alarm": 0, "missed_detection": 0, "confusion": 0} 61 | 62 | metrics["der"] += self.metrics["der"](predictions, labels).cpu().numpy() 63 | metrics["false_alarm"] += self.metrics["false_alarm"](predictions, labels).cpu().numpy() 64 | metrics["missed_detection"] += self.metrics["missed_detection"](predictions, labels).cpu().numpy() 65 | metrics["confusion"] += self.metrics["confusion"](predictions, labels).cpu().numpy() 66 | 67 | return metrics 68 | 69 | 70 | class DataCollator: 71 | """Data collator that will dynamically pad the target labels to have max_speakers_per_chunk""" 72 | 73 | def __init__(self, max_speakers_per_chunk) -> None: 74 | self.max_speakers_per_chunk = max_speakers_per_chunk 75 | 76 | def __call__(self, features): 77 | """_summary_ 78 | 79 | Args: 80 | features (_type_): _description_ 81 | 82 | Returns: 83 | _type_: _description_ 84 | """ 85 | 86 | batch = {} 87 | 88 | speakers = [f["nb_speakers"] for f in features] 89 | labels = [f["labels"] for f in features] 90 | 91 | batch["labels"] = self.pad_targets(labels, speakers) 92 | 93 | batch["waveforms"] = torch.stack([f["waveforms"] for f in features]) 94 | 95 | return batch 96 | 97 | def pad_targets(self, labels, speakers): 98 | """ 99 | labels: 100 | speakers: 101 | 102 | Returns: 103 | _type_: 104 | Collated target tensor of shape (num_frames, self.max_speakers_per_chunk) 105 | If one chunk has more than max_speakers_per_chunk speakers, we keep 106 | the max_speakers_per_chunk most talkative ones. If it has less, we pad with 107 | zeros (artificial inactive speakers). 108 | """ 109 | 110 | targets = [] 111 | 112 | for i in range(len(labels)): 113 | label = speakers[i] 114 | target = labels[i].numpy() 115 | num_speakers = len(label) 116 | 117 | if num_speakers > self.max_speakers_per_chunk: 118 | indices = np.argsort(-np.sum(target, axis=0), axis=0) 119 | target = target[:, indices[: self.max_speakers_per_chunk]] 120 | 121 | elif num_speakers < self.max_speakers_per_chunk: 122 | target = np.pad( 123 | target, 124 | ((0, 0), (0, self.max_speakers_per_chunk - num_speakers)), 125 | mode="constant", 126 | ) 127 | 128 | targets.append(target) 129 | 130 | return torch.from_numpy(np.stack(targets)) 131 | -------------------------------------------------------------------------------- /src/diarizers/test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from pyannote.audio import Inference 4 | from pyannote.audio.pipelines.utils import get_devices 5 | from pyannote.audio.torchmetrics import (DiarizationErrorRate, FalseAlarmRate, 6 | MissedDetectionRate, 7 | SpeakerConfusionRate) 8 | from pyannote.core import SlidingWindow, SlidingWindowFeature 9 | from tqdm import tqdm 10 | 11 | 12 | class Test: 13 | """Class used to evaluate a SegmentationModel at inference time on a test set.""" 14 | 15 | def __init__(self, test_dataset, model, step=2.5): 16 | """_summary_ 17 | 18 | Args: 19 | test_dataset (_type_): _description_ 20 | model (_type_): _description_ 21 | step (float, optional): _description_. Defaults to 2.5. 
22 | """ 23 | 24 | self.test_dataset = test_dataset 25 | self.model = model 26 | (self.device,) = get_devices(needs=1) 27 | self.inference = Inference(self.model, step=step, device=self.device) 28 | 29 | self.sample_rate = test_dataset[0]["audio"]["sampling_rate"] 30 | 31 | # Get the number of frames associated to a chunk: 32 | _, self.num_frames, _ = self.inference.model( 33 | torch.rand((1, int(self.inference.duration * self.sample_rate))).to(self.device) 34 | ).shape 35 | # compute frame resolution: 36 | self.resolution = self.inference.duration / self.num_frames 37 | 38 | self.metrics = { 39 | "der": DiarizationErrorRate(0.5).to(self.device), 40 | "confusion": SpeakerConfusionRate(0.5).to(self.device), 41 | "missed_detection": MissedDetectionRate(0.5).to(self.device), 42 | "false_alarm": FalseAlarmRate(0.5).to(self.device), 43 | } 44 | 45 | def predict(self, file): 46 | """_summary_ 47 | 48 | Args: 49 | file (_type_): _description_ 50 | 51 | Returns: 52 | _type_: _description_ 53 | """ 54 | audio = torch.tensor(file["audio"]["array"]).unsqueeze(0).to(torch.float32).to(self.device) 55 | sample_rate = file["audio"]["sampling_rate"] 56 | 57 | input = {"waveform": audio, "sample_rate": sample_rate} 58 | 59 | prediction = self.inference(input) 60 | 61 | return prediction 62 | 63 | def compute_gt(self, file): 64 | """_summary_ 65 | 66 | Args: 67 | file (_type_): _description_ 68 | 69 | Returns: 70 | _type_: _description_ 71 | """ 72 | 73 | audio = torch.tensor(file["audio"]["array"]).unsqueeze(0).to(torch.float32) 74 | sample_rate = file["audio"]["sampling_rate"] 75 | 76 | audio_duration = len(audio[0]) / sample_rate 77 | num_frames = int(round(audio_duration / self.resolution)) 78 | 79 | labels = list(set(file["speakers"])) 80 | 81 | gt = np.zeros((num_frames, len(labels)), dtype=np.uint8) 82 | 83 | for i in range(len(file["timestamps_start"])): 84 | start = file["timestamps_start"][i] 85 | end = file["timestamps_end"][i] 86 | speaker = file["speakers"][i] 87 | start_frame = int(round(start / self.resolution)) 88 | end_frame = int(round(end / self.resolution)) 89 | speaker_index = labels.index(speaker) 90 | 91 | gt[start_frame:end_frame, speaker_index] += 1 92 | 93 | return gt 94 | 95 | def compute_metrics_on_file(self, file): 96 | """_summary_ 97 | 98 | Args: 99 | file (_type_): _description_ 100 | """ 101 | 102 | gt = self.compute_gt(file) 103 | prediction = self.predict(file) 104 | 105 | sliding_window = SlidingWindow(start=0, step=self.resolution, duration=self.resolution) 106 | labels = list(set(file["speakers"])) 107 | 108 | reference = SlidingWindowFeature(data=gt, labels=labels, sliding_window=sliding_window) 109 | 110 | for window, pred in prediction: 111 | reference_window = reference.crop(window, mode="center") 112 | common_num_frames = min(self.num_frames, reference_window.shape[0]) 113 | 114 | ref_num_frames, ref_num_speakers = reference_window.shape 115 | pred_num_frames, pred_num_speakers = pred.shape 116 | 117 | if pred_num_speakers > ref_num_speakers: 118 | reference_window = np.pad(reference_window, ((0, 0), (0, pred_num_speakers - ref_num_speakers))) 119 | elif ref_num_speakers > pred_num_speakers: 120 | pred = np.pad(pred, ((0, 0), (0, ref_num_speakers - pred_num_speakers))) 121 | 122 | pred = torch.tensor(pred[:common_num_frames]).unsqueeze(0).permute(0, 2, 1).to(self.device) 123 | target = (torch.tensor(reference_window[:common_num_frames]).unsqueeze(0).permute(0, 2, 1)).to(self.device) 124 | 125 | self.metrics["der"](pred, target) 126 | 
self.metrics["false_alarm"](pred, target) 127 | self.metrics["missed_detection"](pred, target) 128 | self.metrics["confusion"](pred, target) 129 | 130 | def compute_metrics(self): 131 | """_summary_ 132 | 133 | Returns: 134 | _type_: _description_ 135 | """ 136 | 137 | for file in tqdm(self.test_dataset): 138 | self.compute_metrics_on_file(file) 139 | 140 | return { 141 | "der": self.metrics["der"].compute(), 142 | "false_alarm": self.metrics["false_alarm"].compute(), 143 | "missed_detection": self.metrics["missed_detection"].compute(), 144 | "confusion": self.metrics["confusion"].compute(), 145 | } 146 | -------------------------------------------------------------------------------- /train_segmentation.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | from pyannote.audio import Model 4 | from transformers import Trainer, TrainingArguments, HfArgumentParser 5 | 6 | from datasets import load_dataset 7 | from diarizers import Preprocess, SegmentationModel, DataCollator, Metrics, train_val_test_split 8 | from dataclasses import dataclass, field 9 | 10 | @dataclass 11 | class DataTrainingArguments: 12 | """ 13 | Arguments pertaining to what data we are going to input our model for training and eval. 14 | 15 | Using `HfArgumentParser` we can turn this class 16 | into argparse arguments to be able to specify them on 17 | the command line. 18 | """ 19 | 20 | dataset_name: str = field( 21 | default=None, 22 | metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 23 | ) 24 | dataset_config_name: str = field( 25 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 26 | ) 27 | 28 | train_split_name: str = field( 29 | default="train", metadata={"help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"} 30 | ) 31 | 32 | eval_split_name: str = field( 33 | default="val", metadata={"help": "The name of the training data set split to use (via the datasets library). Defaults to 'val'"} 34 | ) 35 | 36 | split_on_subset: str = field( 37 | default=None, 38 | metadata={"help": "Automatically splits the dataset into train-val-set on a specified subset. Defaults to 'None'"}, 39 | ) 40 | 41 | preprocessing_num_workers: Optional[int] = field( 42 | default=None, 43 | metadata={"help": "The number of processes to use for the preprocessing."}, 44 | ) 45 | 46 | @dataclass 47 | class ModelArguments: 48 | """ 49 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
50 | """ 51 | 52 | model_name_or_path: str = field( 53 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 54 | ) 55 | 56 | cache_dir: Optional[str] = field( 57 | default=None, 58 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 59 | ) 60 | 61 | if __name__ == "__main__": 62 | 63 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 64 | 65 | parser = HfArgumentParser((DataTrainingArguments, ModelArguments, TrainingArguments)) 66 | 67 | data_args, model_args, training_args = parser.parse_args_into_dataclasses() 68 | 69 | # Load the dataset: 70 | if str(data_args.dataset_config_name): 71 | dataset = load_dataset( 72 | str(data_args.dataset_name), 73 | str(data_args.dataset_config_name), 74 | num_proc=int(data_args.preprocessing_num_workers) 75 | ) 76 | else: 77 | dataset = load_dataset( 78 | str(data_args.dataset_name), 79 | str(data_args.dataset_config_name), 80 | num_proc=int(data_args.preprocessing_num_workers) 81 | ) 82 | 83 | train_split_name = data_args.train_split_name 84 | val_split_name = data_args.eval_split_name 85 | 86 | if data_args.split_on_subset: 87 | dataset = train_val_test_split(dataset[str(data_args.split_on_subset)]) 88 | train_split_name = 'train' 89 | val_split_name = 'val' 90 | 91 | pretrained = Model.from_pretrained( 92 | model_args.model_name_or_path, 93 | cache_dir=model_args.cache_dir, 94 | use_auth_token=True 95 | ) 96 | model = SegmentationModel() 97 | model.from_pyannote_model(pretrained) 98 | 99 | preprocessor = Preprocess(model.config) 100 | 101 | if training_args.do_train: 102 | train_set = dataset['train'].map( 103 | lambda file: preprocessor(file, random=False, overlap=0.5), 104 | num_proc=data_args.preprocessing_num_workers, 105 | remove_columns=next(iter(dataset.values())).column_names, 106 | batched=True, 107 | batch_size=1 108 | ).with_format("torch") 109 | 110 | if training_args.do_eval: 111 | val_set = dataset['validation'].map( 112 | lambda file: preprocessor(file, random=False, overlap=0.0), 113 | num_proc=data_args.preprocessing_num_workers, 114 | remove_columns=next(iter(dataset.values())).column_names, 115 | batched=True, 116 | keep_in_memory=True, 117 | batch_size=1 118 | ).with_format('torch') 119 | 120 | # Load metrics: 121 | metrics = Metrics(model.specifications) 122 | 123 | trainer = Trainer( 124 | model=model, 125 | args=training_args, 126 | train_dataset=train_set, 127 | data_collator=DataCollator(max_speakers_per_chunk=model.config.max_speakers_per_chunk), 128 | eval_dataset=val_set, 129 | compute_metrics=metrics.der_metric, 130 | ) 131 | 132 | if training_args.do_eval: 133 | first_eval = trainer.evaluate() 134 | print("Initial metric values: ", first_eval) 135 | if training_args.do_train: 136 | trainer.train() 137 | 138 | # 14. 
Write Training Stats 139 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "speaker diarization"} 140 | if data_args.dataset_name is not None: 141 | kwargs["dataset_tags"] = data_args.dataset_name 142 | if data_args.dataset_config_name is not None: 143 | kwargs["dataset_args"] = data_args.dataset_config_name 144 | kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}" 145 | else: 146 | kwargs["dataset"] = data_args.dataset_name 147 | 148 | if training_args.push_to_hub: 149 | trainer.push_to_hub(**kwargs) 150 | else: 151 | trainer.create_model_card(**kwargs) 152 | -------------------------------------------------------------------------------- /datasets/spd_datasets.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import glob 4 | import os 5 | 6 | from pydub import AudioSegment 7 | 8 | from src.diarizers.data.speaker_diarization import SpeakerDiarizationDataset 9 | 10 | 11 | def get_ami_files(path_to_ami, setup="only_words", hm_type="ihm"): 12 | 13 | """_summary_ 14 | 15 | Returns: 16 | _type_: _description_ 17 | """ 18 | assert setup in ["only_words", "mini"] 19 | assert hm_type in ["ihm", "sdm"] 20 | 21 | rttm_files = { 22 | "train": glob.glob(path_to_ami + "/AMI-diarization-setup/{}/rttms/{}/*.rttm".format(setup, "train")), 23 | "validation": glob.glob(path_to_ami + "/AMI-diarization-setup/{}/rttms/{}/*.rttm".format(setup, "dev")), 24 | "test": glob.glob(path_to_ami + "/AMI-diarization-setup/{}/rttms/{}/*.rttm".format(setup, "test")), 25 | } 26 | 27 | audio_files = { 28 | "train": [], 29 | "validation": [], 30 | "test": [], 31 | } 32 | 33 | for subset in rttm_files: 34 | 35 | rttm_list = copy.deepcopy(rttm_files[subset]) 36 | 37 | for rttm in rttm_list: 38 | meeting = rttm.split("/")[-1].split(".")[0] 39 | if hm_type == "ihm": 40 | path = path_to_ami + "/AMI-diarization-setup/pyannote/amicorpus/{}/audio/{}.Mix-Headset.wav".format( 41 | meeting, meeting 42 | ) 43 | if os.path.exists(path): 44 | audio_files[subset].append(path) 45 | else: 46 | rttm_files[subset].remove(rttm) 47 | if hm_type == "sdm": 48 | path = path_to_ami + "/AMI-diarization-setup/pyannote/amicorpus/{}/audio/{}.Array1-01.wav".format( 49 | meeting, meeting 50 | ) 51 | if os.path.exists(path): 52 | audio_files[subset].append(path) 53 | else: 54 | rttm_files[subset].remove(rttm) 55 | 56 | return audio_files, rttm_files 57 | 58 | 59 | def get_callhome_files(path_to_callhome, langage="jpn"): 60 | 61 | audio_files = glob.glob(path_to_callhome + "/callhome/{}/*.mp3".format(langage)) 62 | 63 | audio_files = { 64 | "data": audio_files, 65 | } 66 | cha_files = { 67 | "data": [], 68 | } 69 | 70 | for subset in audio_files: 71 | for cha_path in audio_files[subset]: 72 | file = cha_path.split("/")[-1].split(".")[0] 73 | cha_files[subset].append(path_to_callhome + "/callhome/{}/{}.cha".format(langage, file)) 74 | 75 | return audio_files, cha_files 76 | 77 | 78 | def get_simsamu_files(path_to_simsamu): 79 | 80 | rttm_files = glob.glob(path_to_simsamu + "/simsamu/*/*.rttm") 81 | audio_files = glob.glob(path_to_simsamu + "/simsamu/*/*.m4a") 82 | 83 | for file in audio_files: 84 | sound = AudioSegment.from_file(file, format="m4a") 85 | file.split("/") 86 | file_hanlde = sound.export(file.split(".")[0] + ".wav", format="wav") 87 | 88 | audio_files = glob.glob(path_to_simsamu + "/simsamu/*/*.wav") 89 | 90 | audio_files = {"data": audio_files} 91 | 92 | rttm_files = {"data": rttm_files} 93 | 94 | return audio_files, 
rttm_files 95 | 96 | 97 | def get_voxconverse_files(path_to_voxconverse): 98 | 99 | rttm_files = { 100 | "dev": glob.glob(path_to_voxconverse + "/voxconverse/dev/*.rttm"), 101 | "test": glob.glob(path_to_voxconverse + "/voxconverse/test/*.rttm"), 102 | } 103 | 104 | audio_files = { 105 | "dev": glob.glob(path_to_voxconverse + "/voxconverse/audio/*.wav"), 106 | "test": glob.glob(path_to_voxconverse + "/voxconverse/voxconverse_test_wav/*.wav"), 107 | } 108 | 109 | return audio_files, rttm_files 110 | 111 | 112 | if __name__ == "__main__": 113 | 114 | parser = argparse.ArgumentParser() 115 | parser.add_argument("--dataset", required=True) 116 | parser.add_argument("--path_to_dataset", required=True) 117 | parser.add_argument("--setup", required=False, default="only_words") 118 | parser.add_argument("--push_to_hub", required=False, default=False) 119 | 120 | parser.add_argument("--hub_repository", required=False) 121 | args = parser.parse_args() 122 | 123 | if args.dataset == "ami": 124 | 125 | audio_files, rttm_files = get_ami_files(path_to_ami=args.path_to_dataset, setup=args.setup, hm_type="ihm") 126 | ami_dataset_ihm = SpeakerDiarizationDataset(audio_files, rttm_files).construct_dataset() 127 | if args.push_to_hub == "True": 128 | ami_dataset_ihm.push_to_hub(args.hub_repository, "ihm") 129 | 130 | audio_files, rttm_files = get_ami_files(path_to_ami=args.path_to_dataset, setup=args.setup, hm_type="sdm") 131 | ami_dataset_sdm = SpeakerDiarizationDataset(audio_files, rttm_files).construct_dataset() 132 | if args.push_to_hub == "True": 133 | ami_dataset_sdm.push_to_hub(args.hub_repository, "sdm") 134 | 135 | if args.dataset == "callhome": 136 | 137 | langages = ["eng", "jpn", "spa", "zho", "deu"] 138 | 139 | for langage in langages: 140 | audio_files, cha_files = get_callhome_files(args.path_to_dataset, langage=langage) 141 | dataset = SpeakerDiarizationDataset( 142 | audio_files, cha_files, annotations_type="cha", crop_unannotated_regions=True 143 | ).construct_dataset(num_proc=24) 144 | 145 | if args.push_to_hub == "True": 146 | dataset.push_to_hub(args.hub_repository, str(langage)) 147 | 148 | if args.dataset == "simsamu": 149 | audio_files, rttm_files = get_simsamu_files(args.path_to_dataset) 150 | dataset = SpeakerDiarizationDataset(audio_files, rttm_files).construct_dataset() 151 | 152 | if args.push_to_hub == "True": 153 | dataset.push_to_hub(args.hub_repository) 154 | 155 | if args.dataset == "voxconverse": 156 | audio_files, rttm_files = get_voxconverse_files(args.path_to_dataset) 157 | dataset = SpeakerDiarizationDataset(audio_files, rttm_files).construct_dataset() 158 | print(dataset) 159 | if args.push_to_hub == "True": 160 | dataset.push_to_hub(args.hub_repository) 161 | -------------------------------------------------------------------------------- /utils/check_dummies.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import argparse 17 | import os 18 | import re 19 | 20 | # All paths are set with the intent you should run this script from the root of the repo with the command 21 | # python utils/check_dummies.py 22 | PATH_TO_DIARIZERS = "src/diarizers" 23 | 24 | # Matches is_xxx_available() 25 | _re_backend = re.compile(r"is\_([a-z_]*)_available\(\)") 26 | # Matches from xxx import bla 27 | _re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n") 28 | 29 | 30 | DUMMY_CONSTANT = """ 31 | {0} = None 32 | """ 33 | 34 | DUMMY_CLASS = """ 35 | class {0}(metaclass=DummyObject): 36 | _backends = {1} 37 | 38 | def __init__(self, *args, **kwargs): 39 | requires_backends(self, {1}) 40 | 41 | @classmethod 42 | def from_config(cls, *args, **kwargs): 43 | requires_backends(cls, {1}) 44 | 45 | @classmethod 46 | def from_pretrained(cls, *args, **kwargs): 47 | requires_backends(cls, {1}) 48 | """ 49 | 50 | 51 | DUMMY_FUNCTION = """ 52 | def {0}(*args, **kwargs): 53 | requires_backends({0}, {1}) 54 | """ 55 | 56 | 57 | def find_backend(line): 58 | """Find one (or multiple) backend in a code line of the init.""" 59 | backends = _re_backend.findall(line) 60 | if len(backends) == 0: 61 | return None 62 | 63 | return "_and_".join(backends) 64 | 65 | 66 | def read_init(): 67 | """Read the init and extracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects.""" 68 | with open(os.path.join(PATH_TO_DIARIZERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f: 69 | lines = f.readlines() 70 | 71 | # Get to the point we do the actual imports for type checking 72 | line_index = 0 73 | backend_specific_objects = {} 74 | # Go through the end of the file 75 | while line_index < len(lines): 76 | # If the line contains is_backend_available, we grab all objects associated with the `else` block 77 | backend = find_backend(lines[line_index]) 78 | if backend is not None: 79 | line_index += 1 80 | objects = [] 81 | # Until we unindent, add backend objects to the list 82 | while line_index < len(lines) and len(lines[line_index]) > 1 and not lines[line_index].startswith("else:"): 83 | line = lines[line_index] 84 | single_line_import_search = _re_single_line_import.search(line) 85 | if single_line_import_search is not None: 86 | objects.extend(single_line_import_search.groups()[0].split(", ")) 87 | elif line.startswith(" " * 4): 88 | objects.append(line[4:-2]) 89 | line_index += 1 90 | 91 | if len(objects) > 0: 92 | backend_specific_objects[backend] = objects 93 | else: 94 | line_index += 1 95 | 96 | return backend_specific_objects 97 | 98 | 99 | def create_dummy_object(name, backend_name): 100 | """Create the code for the dummy object corresponding to `name`.""" 101 | if name.isupper(): 102 | return DUMMY_CONSTANT.format(name) 103 | elif name.islower(): 104 | return DUMMY_FUNCTION.format(name, backend_name) 105 | else: 106 | return DUMMY_CLASS.format(name, backend_name) 107 | 108 | 109 | def create_dummy_files(backend_specific_objects=None): 110 | """Create the content of the dummy files.""" 111 | if backend_specific_objects is None: 112 | backend_specific_objects = read_init() 113 | # For special correspondence backend to module name as used in the function requires_modulename 114 | dummy_files = {} 115 | 116 | for backend, objects in backend_specific_objects.items(): 117 | backend_name = "[" + ", ".join(f'"{b}"' for b in backend.split("_and_")) + "]" 118 | backend_name = 
backend_name.replace("pyannote", "pyannote.audio") 119 | dummy_file = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n" 120 | dummy_file += "# flake8: noqa\n\n" 121 | dummy_file += "from ..utils import DummyObject, requires_backends\n\n" 122 | dummy_file += "\n".join([create_dummy_object(o, backend_name) for o in objects]) 123 | dummy_files[backend] = dummy_file 124 | 125 | return dummy_files 126 | 127 | 128 | def check_dummies(overwrite=False): 129 | """Check if the dummy files are up to date and maybe `overwrite` with the right content.""" 130 | dummy_files = create_dummy_files() 131 | # For special correspondence backend to shortcut as used in utils/dummy_xxx_objects.py 132 | short_names = {"torch": "pt"} 133 | 134 | # Locate actual dummy modules and read their content. 135 | path = os.path.join(PATH_TO_DIARIZERS, "utils") 136 | dummy_file_paths = { 137 | backend: os.path.join(path, f"dummy_{short_names.get(backend, backend)}_objects.py") 138 | for backend in dummy_files.keys() 139 | } 140 | 141 | actual_dummies = {} 142 | for backend, file_path in dummy_file_paths.items(): 143 | if os.path.isfile(file_path): 144 | with open(file_path, "r", encoding="utf-8", newline="\n") as f: 145 | actual_dummies[backend] = f.read() 146 | else: 147 | actual_dummies[backend] = "" 148 | 149 | for backend in dummy_files.keys(): 150 | if dummy_files[backend] != actual_dummies[backend]: 151 | if overwrite: 152 | print( 153 | f"Updating diarizers.utils.dummy_{short_names.get(backend, backend)}_objects.py as the main " 154 | "__init__ has new objects." 155 | ) 156 | with open(dummy_file_paths[backend], "w", encoding="utf-8", newline="\n") as f: 157 | f.write(dummy_files[backend]) 158 | else: 159 | raise ValueError( 160 | "The main __init__ has objects that are not present in " 161 | f"diarizers.utils.dummy_{short_names.get(backend, backend)}_objects.py. Run `make fix-copies` " 162 | "to fix this." 163 | ) 164 | 165 | 166 | if __name__ == "__main__": 167 | parser = argparse.ArgumentParser() 168 | parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.") 169 | args = parser.parse_args() 170 | 171 | check_dummies(args.fix_and_overwrite) 172 | -------------------------------------------------------------------------------- /src/diarizers/data/preprocess.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/pyannote/pyannote-audio/blob/develop/pyannote/audio/tasks/segmentation/speaker_diarization.py 2 | import math 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from ..models import SegmentationModel 8 | 9 | 10 | class Preprocess: 11 | """Converts a HF dataset with the following features: 12 | - "audio": Audio feature. 13 | - "speakers": The list of audio speakers, with their order of appearance. 14 | - "timestamps_start": A list of timestamps indicating the start of each speaker segment. 15 | flake8>=3.8.3 16 | - "timestamps_end": A list of timestamps indicating the end of each speaker segment. 17 | to a preprocessed dataset ready to be used with the HF Trainer. 18 | """ 19 | 20 | def __init__( 21 | self, 22 | config, 23 | ): 24 | """Preprocess init method. 25 | Takes as input the dataset to process and the model to perform training with. 26 | The preprocessing is done to fit the hyperparameters of the model. 
27 | Args: 28 | input_dataset (dataset): Hugging Face Speaker Diarization dataset 29 | model (SegmentationModel): A SegmentationModel from the diarizers library. 30 | """ 31 | self.chunk_duration = config.chunk_duration 32 | self.max_speakers_per_frame = config.max_speakers_per_frame 33 | self.max_speakers_per_chunk = config.max_speakers_per_chunk 34 | self.min_duration = config.min_duration 35 | self.warm_up = config.warm_up 36 | 37 | self.sample_rate = config.sample_rate 38 | model = SegmentationModel(config).to_pyannote_model() 39 | 40 | # Get the number of frames associated to a chunk: 41 | _, self.num_frames_per_chunk, _ = model(torch.rand((1, int(self.chunk_duration * self.sample_rate)))).shape 42 | 43 | def get_labels_in_file(self, file): 44 | """Get speakers present in file. 45 | Args: 46 | file (_type_): dataset row from the input dataset. 47 | 48 | Returns: 49 | file_labels (list): a list of all speakers in the audio file. 50 | """ 51 | 52 | file_labels = [] 53 | for i in range(len(file["speakers"][0])): 54 | if file["speakers"][0][i] not in file_labels: 55 | file_labels.append(file["speakers"][0][i]) 56 | 57 | return file_labels 58 | 59 | def get_segments_in_file(self, file, labels): 60 | """Get segments present in file. 61 | 62 | Args: 63 | file (_type_): _description_ 64 | labels (_type_): _description_ 65 | 66 | Returns: 67 | annotations (numpy array): _description_ 68 | """ 69 | 70 | file_annotations = [] 71 | 72 | for i in range(len(file["timestamps_start"][0])): 73 | start_segment = file["timestamps_start"][0][i] 74 | end_segment = file["timestamps_end"][0][i] 75 | label = labels.index(file["speakers"][0][i]) 76 | file_annotations.append((start_segment, end_segment, label)) 77 | 78 | dtype = [("start", " start_time)] 113 | 114 | # compute frame resolution: 115 | resolution = self.chunk_duration / self.num_frames_per_chunk 116 | 117 | # discretize chunk annotations at model output resolution 118 | start = np.maximum(chunk_segments["start"], start_time) - start_time 119 | start_idx = np.floor(start / resolution).astype(int) 120 | end = np.minimum(chunk_segments["end"], end_time) - start_time 121 | end_idx = np.ceil(end / resolution).astype(int) 122 | 123 | # get list and number of labels for current scope 124 | labels = list(np.unique(chunk_segments["labels"])) 125 | num_labels = len(labels) 126 | # initial frame-level targets 127 | y = np.zeros((self.num_frames_per_chunk, num_labels), dtype=np.uint8) 128 | 129 | # map labels to indices 130 | mapping = {label: idx for idx, label in enumerate(labels)} 131 | 132 | for start, end, label in zip(start_idx, end_idx, chunk_segments["labels"]): 133 | mapped_label = mapping[label] 134 | y[start:end, mapped_label] = 1 135 | 136 | return waveform, y, labels 137 | 138 | def get_start_positions(self, file, overlap, random=False): 139 | """Get the start positions of the audio_chunks in the input audio file. 140 | 141 | Args: 142 | file (dict): dataset row containing the "audio" feature. 143 | overlap (float, optional): Overlap between successive start positions. 144 | random (bool, optional): Whether or not to randomly select chunks in the audio file. Defaults to False. 145 | 146 | Returns: 147 | start_positions: Numpy array containing the start positions of the audio chunks in file. 
148 | """ 149 | 150 | sample_rate = file["audio"][0]["sampling_rate"] 151 | 152 | assert sample_rate == self.sample_rate 153 | 154 | file_duration = len(file["audio"][0]["array"]) / sample_rate 155 | start_positions = np.arange(0, file_duration - self.chunk_duration, self.chunk_duration * (1 - overlap)) 156 | 157 | if random: 158 | nb_samples = int(file_duration / self.chunk_duration) 159 | start_positions = np.random.uniform(0, file_duration, nb_samples) 160 | 161 | return start_positions 162 | 163 | def __call__(self, file, random=False, overlap=0.0): 164 | """Chunk an audio file into short segments of duration self.chunk_duration 165 | 166 | Args: 167 | file (dict): dataset row containing the "audio" feature. 168 | random (bool, optional): Whether or not to randomly select chunks in the audio file. Defaults to False. 169 | overlap (float, optional): Overlap between successive chunks. Defaults to 0.0. 170 | 171 | Returns: 172 | new_batch: new batch containing for each chunk the corresponding waveform, labels and number of speakers. 173 | """ 174 | 175 | new_batch = {"waveforms": [], "labels": [], "nb_speakers": []} 176 | 177 | if random: 178 | start_positions = self.get_start_positions(file, overlap, random=True) 179 | else: 180 | start_positions = self.get_start_positions(file, overlap) 181 | 182 | for start_time in start_positions: 183 | waveform, target, label = self.get_chunk(file, start_time) 184 | 185 | new_batch["waveforms"].append(waveform) 186 | new_batch["labels"].append(target) 187 | new_batch["nb_speakers"].append(label) 188 | 189 | return new_batch 190 | -------------------------------------------------------------------------------- /src/diarizers/data/speaker_diarization.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/hbredin/pyannote-db-callhome/blob/master/parse_transcripts.py 2 | import numpy as np 3 | 4 | from datasets import Audio, Dataset, DatasetDict 5 | 6 | 7 | def get_secs(x): 8 | return x * 4 * 2.0 / 8000 9 | 10 | 11 | def get_start_end(t1, t2): 12 | t1 = get_secs(t1) 13 | t2 = get_secs(t2) 14 | return t1, t2 15 | 16 | 17 | def represent_int(s): 18 | try: 19 | int(s) 20 | return True 21 | except ValueError as e: 22 | return False 23 | 24 | 25 | class SpeakerDiarizationDataset: 26 | """ 27 | Convert a speaker diarization dataset made of