├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── dover_lap ├── __init__.py ├── dover_lap.py ├── libs │ ├── __init__.py │ ├── rttm.py │ ├── turn.py │ ├── uem.py │ └── utils.py └── src │ ├── __init__.py │ ├── doverlap.py │ ├── label_mapping.py │ ├── label_voting.py │ ├── mapping │ ├── __init__.py │ ├── greedy.py │ ├── hungarian.py │ └── map_utils.py │ └── voting │ ├── __init__.py │ └── average.py ├── egs └── ami │ ├── ref_rttm_test │ ├── rttm_dl_test │ ├── rttm_test_rpn │ ├── rttm_test_sc │ └── rttm_test_vb ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: stable 4 | hooks: 5 | - id: black 6 | language_version: python3.6 7 | - repo: https://gitlab.com/pycqa/flake8 8 | rev: 3.7.9 9 | hooks: 10 | - id: flake8 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Desh Raj 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DOVER-Lap 2 | Official implementation for [DOVER-Lap: A method for combining overlap-aware diarization outputs](https://arxiv.org/pdf/2011.01997.pdf). 3 | 4 | ## Installation 5 | 6 | ```shell 7 | pip install dover-lap 8 | ``` 9 | 10 | ## How to run 11 | 12 | After installation, run 13 | 14 | ```shell 15 | dover-lap [OPTIONS] OUTPUT_RTTM [INPUT_RTTMS]... 16 | ``` 17 | 18 | Example: 19 | 20 | ```shell 21 | dover-lap egs/ami/rttm_dl_test egs/ami/rttm_test_* 22 | ``` 23 | 24 | ## Usage instructions 25 | 26 | ```shell 27 | Usage: dover-lap [OPTIONS] OUTPUT_RTTM [INPUT_RTTMS]... 28 | 29 | Apply the DOVER-Lap algorithm on the input RTTM files. 30 | 31 | Options: 32 | --gaussian-filter-std FLOAT Standard deviation for Gaussian filter 33 | applied before voting. This can help reduce 34 | the effect of outliers in the input RTTMs. 35 | For quick turn-taking, set this to a small 36 | value (e.g. 0.1). 0.5 is a good value for 37 | most cases. Set this to a very small value, 38 | e.g. 0.01, to remove filtering. 
[default: 39 | 0.5] 40 | 41 | --custom-weight TEXT Weights for input RTTMs 42 | --dover-weight FLOAT DOVER weighting factor [default: 0.1] 43 | --weight-type [rank|custom|norm] 44 | Specify whether to use rank weighting or 45 | provide custom weights [default: rank] 46 | 47 | --voting-method [average] Choose voting method to use: average: use 48 | weighted average to combine input RTTMs 49 | [default: average] 50 | 51 | --second-maximal If this flag is set, run a second iteration 52 | of the maximal matching for greedy label 53 | mapping [default: False] 54 | 55 | --label-mapping [hungarian|greedy] 56 | Choose label mapping algorithm to use 57 | [default: greedy] 58 | 59 | --random-seed INTEGER 60 | -c, --channel INTEGER Use this value for output channel IDs 61 | [default: 1] 62 | 63 | -u, --uem-file PATH UEM file path 64 | --help Show this message and exit. 65 | ``` 66 | 67 | **Note:** 68 | 69 | 1. If `--weight-type custom` is used, then `--custom-weight` must be provided. For example: 70 | 71 | ```shell 72 | dover-lap egs/ami/rttm_dl_test egs/ami/rttm_test_* --weight-type custom --custom-weight '[0.4,0.3,0.3]' 73 | ``` 74 | 75 | 2. `label-mapping` can be set to `greedy` (default) or `hungarian`, which is a modified version of the mapping 76 | technique originally proposed in [DOVER](https://arxiv.org/abs/1909.08090). 77 | 78 | ## Results 79 | 80 | We provide a sample result on the AMI mix-headset test set. The results can be 81 | obtained using [`spyder`](https://github.com/desh2608/spyder), which is automatically 82 | installed with `dover-lap`: 83 | 84 | ```shell 85 | dover-lap egs/ami/rttm_dl_test egs/ami/rttm_test_* 86 | spyder egs/ami/ref_rttm_test egs/ami/rttm_dl_test 87 | ``` 88 | 89 | and similarly for the input hypothesis. The DER results are shown below. 90 | 91 | | | MS | FA | Conf. 
| DER | 91 | 92 |-----------------------------------|:-----:|:----:|:-----:|:-----:| 93 | | Overlap-aware VB resegmentation | 9.84 | **2.06** | 9.60 | 21.50 | 94 | | Overlap-aware spectral clustering | 11.48 | 2.27 | 9.81 | 23.56 | 95 | | Region Proposal Network | **9.49** | 7.68 | 8.25 | 25.43 | 96 | | DOVER-Lap (Hungarian mapping) | 9.98 | 2.13 | 8.25 | 20.35 | 97 | | DOVER-Lap (Greedy mapping)* | 9.96 | 2.16 | **7.75** | **19.86** | 98 | 99 | _* The Greedy label mapping is exponential in the number of inputs (see [this paper](https://arxiv.org/abs/2104.01954))._ 100 | 101 | ## Running time 102 | 103 | The algorithm is implemented in pure Python with NumPy for tensor computations. 104 | The time complexity is expected to increase exponentially with the number of 105 | inputs, but it should be reasonable for combining up to 10 input hypotheses. For 106 | combining more than 10 inputs, we recommend setting `--label-mapping hungarian`. 107 | 108 | For a small number of inputs (up to 5), the algorithm should take only a few seconds 109 | to run on a laptop. 110 | 111 | ## Combining 2 systems with DOVER-Lap 112 | 113 | DOVER-Lap is meant to be used to combine **more than 2 systems**, since 114 | black-box voting between 2 systems does not make much sense. Still, if 2 systems 115 | are provided as input, we fall back on the Hungarian algorithm for label mapping, 116 | since it is provably optimal for this case. Both systems are assigned equal 117 | weights, and in case of voting conflicts, the region is assigned to both 118 | labels. This is not the intended use case and will almost certainly lead 119 | to performance degradation. 
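The two-system fallback described above can be sketched in a few lines of plain Python. This is an illustrative toy, not the package's actual implementation: `best_mapping` and the example hypotheses are made up here, and it brute-forces all one-to-one speaker assignments to maximize total time overlap, which is the same objective the Hungarian algorithm solves for two inputs (just less efficiently):

```python
from itertools import permutations


def overlap(seg_a, seg_b):
    """Overlap duration in seconds between two (onset, offset) segments."""
    return max(0.0, min(seg_a[1], seg_b[1]) - max(seg_a[0], seg_b[0]))


def best_mapping(hyp_a, hyp_b):
    """Map speakers of ``hyp_b`` onto speakers of ``hyp_a``.

    Each hypothesis is a dict: speaker label -> list of (onset, offset) turns.
    Tries every one-to-one assignment (assumes equal speaker counts) and keeps
    the one with the largest total overlap -- the Hungarian objective for the
    pairwise case.
    """
    spk_a, spk_b = list(hyp_a), list(hyp_b)
    best, best_score = None, -1.0
    for perm in permutations(spk_a, len(spk_b)):
        # Total overlap of this candidate assignment.
        score = sum(
            overlap(a, b)
            for pa, pb in zip(perm, spk_b)
            for a in hyp_a[pa]
            for b in hyp_b[pb]
        )
        if score > best_score:
            best, best_score = dict(zip(spk_b, perm)), score
    return best


hyp1 = {"A": [(0.0, 5.0)], "B": [(5.0, 9.0)]}
hyp2 = {"x": [(4.8, 9.2)], "y": [(0.2, 4.8)]}
print(best_mapping(hyp1, hyp2))  # -> {'x': 'B', 'y': 'A'}
```

With the real tool you would simply pass the two RTTMs on the command line as shown earlier; this sketch only illustrates why the pairwise mapping has an exact solution.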
120 | 121 | ## Citation 122 | 123 | ``` 124 | @article{Raj2021Doverlap, 125 | title={{DOVER-Lap}: A Method for Combining Overlap-aware Diarization Outputs}, 126 | author={D.Raj and P.Garcia and Z.Huang and S.Watanabe and D.Povey and A.Stolcke and S.Khudanpur}, 127 | journal={2021 IEEE Spoken Language Technology Workshop (SLT)}, 128 | year={2021} 129 | } 130 | 131 | @article{Raj2021ReformulatingDL, 132 | title={Reformulating {DOVER-Lap} Label Mapping as a Graph Partitioning Problem}, 133 | author={Desh Raj and S. Khudanpur}, 134 | journal={INTERSPEECH}, 135 | year={2021}, 136 | } 137 | ``` 138 | 139 | ## Contact 140 | 141 | For issues/bug reports, please raise an Issue in this repository, or reach out to me at `draj@cs.jhu.edu`. 142 | -------------------------------------------------------------------------------- /dover_lap/__init__.py: -------------------------------------------------------------------------------- 1 | from .libs import * 2 | from .src import * 3 | -------------------------------------------------------------------------------- /dover_lap/dover_lap.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | This is the official implementation for the DOVER-Lap algorithm. It combines 4 | overlap-aware diarization hypotheses to produce an output RTTM. 5 | 6 | Raj, D., García-Perera, L.P., Huang, Z., Watanabe, S., Povey, D., Stolcke, A., & Khudanpur, S. 7 | DOVER-Lap: A Method for Combining Overlap-aware Diarization Outputs. 8 | IEEE Spoken Language Technology Workshop 2021. 
9 | 10 | Copyright 2020 Desh Raj (Johns Hopkins University) 11 | """ 12 | import sys 13 | import click 14 | import random 15 | import numpy as np 16 | 17 | from typing import List 18 | 19 | from dover_lap.libs.rttm import load_rttm, write_rttm 20 | from dover_lap.libs.turn import merge_turns, trim_turns, Turn 21 | from dover_lap.libs.uem import load_uem 22 | from dover_lap.libs.utils import ( 23 | error, 24 | info, 25 | warn, 26 | groupby, 27 | command_required_option, 28 | PythonLiteralOption, 29 | ) 30 | 31 | from dover_lap.src.doverlap import DOVERLap 32 | 33 | 34 | def load_rttms(rttm_list: List[str]) -> List[List[Turn]]: 35 | """Loads speaker turns from input RTTMs in a list of turns.""" 36 | turns_list = [] 37 | file_ids = [] 38 | for rttm_fn in sorted(rttm_list): 39 | try: 40 | turns_, _, file_ids_ = load_rttm(rttm_fn) 41 | turns_list.append(turns_) 42 | file_ids += list(file_ids_) 43 | except IOError as e: 44 | error("Invalid RTTM file: %s. %s" % (rttm_fn, e)) 45 | sys.exit(1) 46 | return turns_list 47 | 48 | 49 | @click.argument("input_rttms", nargs=-1, type=click.Path(exists=True)) 50 | @click.argument("output_rttm", nargs=1, type=click.Path()) 51 | @click.option("-u", "--uem-file", type=click.Path(), help="UEM file path") 52 | @click.option( 53 | "-c", 54 | "--channel", 55 | type=int, 56 | default=1, 57 | show_default=True, 58 | help="Use this value for output channel IDs", 59 | ) 60 | @click.option("--random-seed", type=int, default=0) 61 | @click.option( 62 | "--label-mapping", 63 | type=click.Choice(["hungarian", "greedy"]), 64 | default="greedy", 65 | show_default=True, 66 | help="Choose label mapping algorithm to use", 67 | ) 68 | @click.option( 69 | "--second-maximal", 70 | is_flag=True, 71 | default=False, 72 | show_default=True, 73 | help="If this flag is set, run a second iteration of the maximal matching for" 74 | " greedy label mapping", 75 | ) 76 | @click.option( 77 | "--voting-method", 78 | type=click.Choice(["average"]), 79 | 
default="average", 80 | show_default=True, 81 | help="Choose voting method to use:" 82 | " average: use weighted average to combine input RTTMs", 83 | ) 84 | @click.option( 85 | "--weight-type", 86 | type=click.Choice(["rank", "custom", "norm"]), 87 | default="rank", 88 | help="Specify whether to use rank weighting or provide custom weights", 89 | show_default=True, 90 | ) 91 | @click.option( 92 | "--dover-weight", 93 | type=float, 94 | default=0.1, 95 | help="DOVER weighting factor", 96 | show_default=True, 97 | ) 98 | @click.option( 99 | "--custom-weight", cls=PythonLiteralOption, help="Weights for input RTTMs" 100 | ) 101 | @click.option( 102 | "--gaussian-filter-std", 103 | type=float, 104 | default=0.5, 105 | help="Standard deviation for Gaussian filter applied before voting. This can help" 106 | " reduce the effect of outliers in the input RTTMs. For quick turn-taking, set" 107 | " this to a small value (e.g. 0.1). 0.5 is a good value for most cases. Set this" 108 | " to a very small value, e.g. 0.01, to remove filtering.", 109 | show_default=True, 110 | ) 111 | @click.command( 112 | cls=command_required_option( 113 | "weight_type", {"custom": "custom_weight", "rank": "dover_weight", "norm": None} 114 | ) 115 | ) 116 | def main( 117 | input_rttms: List[click.Path], 118 | output_rttm: click.Path, 119 | uem_file: click.Path, 120 | channel: int, 121 | random_seed: int, 122 | **kwargs, # these are passed directly to combine_turns_list() method 123 | ) -> None: 124 | """Apply the DOVER-Lap algorithm on the input RTTM files.""" 125 | 126 | # Set random seeds globally 127 | random.seed(random_seed) 128 | np.random.seed(random_seed) 129 | 130 | # Load hypothesis speaker turns. 
131 | info("Loading speaker turns from input RTTMs...", file=sys.stderr) 132 | turns_list = load_rttms(input_rttms) 133 | 134 | if uem_file is not None: 135 | info("Loading universal evaluation map...", file=sys.stderr) 136 | uem = load_uem(uem_file) 137 | 138 | # Trim turns to UEM scoring regions and merge any that overlap. 139 | info( 140 | "Trimming reference speaker turns to UEM scoring regions...", 141 | file=sys.stderr, 142 | ) 143 | turns_list = [trim_turns(turns, uem) for turns in turns_list] 144 | 145 | info("Merging overlapping speaker turns...", file=sys.stderr) 146 | turns_list = [merge_turns(turns) for turns in turns_list] 147 | 148 | file_to_turns_list = dict() 149 | for turns in turns_list: 150 | for fid, g in groupby(turns, lambda x: x.file_id): 151 | if fid in file_to_turns_list: 152 | file_to_turns_list[fid].append(list(g)) 153 | else: 154 | file_to_turns_list[fid] = [list(g)] 155 | 156 | # Run DOVER-Lap algorithm 157 | file_to_out_turns = dict() 158 | for file_id in file_to_turns_list: 159 | info("Processing file {}..".format(file_id), file=sys.stderr) 160 | turns_list = file_to_turns_list[file_id] 161 | random.shuffle( 162 | turns_list 163 | ) # We shuffle so that the hypothesis order is randomized 164 | file_to_out_turns[file_id] = DOVERLap.combine_turns_list( 165 | turns_list, file_id, **kwargs 166 | ) 167 | 168 | # Write output RTTM file 169 | write_rttm(output_rttm, sum(list(file_to_out_turns.values()), []), channel=channel) 170 | -------------------------------------------------------------------------------- /dover_lap/libs/__init__.py: -------------------------------------------------------------------------------- 1 | from .rttm import * 2 | from .turn import * 3 | from .uem import * 4 | from .utils import * 5 | -------------------------------------------------------------------------------- /dover_lap/libs/rttm.py: -------------------------------------------------------------------------------- 1 | """Functions for reading/writing RTTM 
files. 2 | 3 | Taken from https://github.com/nryant/dscore 4 | """ 5 | from .turn import Turn 6 | from .utils import format_float 7 | 8 | 9 | def _parse_rttm_line(line): 10 | line = line.decode("utf-8").strip() 11 | fields = line.split() 12 | if len(fields) < 9: 13 | raise IOError('Number of fields < 9. LINE: "%s"' % line) 14 | file_id = fields[1] 15 | speaker_id = fields[7] 16 | 17 | # Check valid turn onset. 18 | try: 19 | onset = float(fields[3]) 20 | except ValueError: 21 | raise IOError('Turn onset not FLOAT. LINE: "%s"' % line) 22 | if onset < 0: 23 | raise IOError('Turn onset < 0 seconds. LINE: "%s"' % line) 24 | 25 | # Check valid turn duration. 26 | try: 27 | dur = float(fields[4]) 28 | except ValueError: 29 | raise IOError('Turn duration not FLOAT. LINE: "%s"' % line) 30 | if dur <= 0: 31 | raise IOError('Turn duration <= 0 seconds. LINE: "%s"' % line) 32 | 33 | return Turn(onset, dur=dur, speaker_id=speaker_id, file_id=file_id) 34 | 35 | 36 | def load_rttm(rttmf): 37 | """Load speaker turns from RTTM file. 38 | 39 | For a description of the RTTM format, consult Appendix A of the NIST RT-09 40 | evaluation plan. 41 | 42 | Parameters 43 | ---------- 44 | rttmf : str 45 | Path to RTTM file. 46 | 47 | Returns 48 | ------- 49 | turns : list of Turn 50 | Speaker turns. 51 | 52 | speaker_ids : set 53 | Speaker ids present in ``rttmf``. 54 | 55 | file_ids : set 56 | File ids present in ``rttmf``. 57 | 58 | References 59 | ---------- 60 | NIST. (2009). The 2009 (RT-09) Rich Transcription Meeting Recognition 61 | Evaluation Plan. 
https://web.archive.org/web/20100606041157if_/http://www.itl.nist.gov/iad/mig/tests/rt/2009/docs/rt09-meeting-eval-plan-v2.pdf 62 | """ 63 | with open(rttmf, "rb") as f: 64 | turns = [] 65 | speaker_ids = set() 66 | file_ids = set() 67 | for line in f: 68 | if line.startswith(b"SPKR-INFO"): 69 | continue 70 | turn = _parse_rttm_line(line) 71 | turns.append(turn) 72 | speaker_ids.add(turn.speaker_id) 73 | file_ids.add(turn.file_id) 74 | return turns, speaker_ids, file_ids 75 | 76 | 77 | def write_rttm(rttmf, turns, n_digits=3, channel=1): 78 | """Write speaker turns to RTTM file. 79 | 80 | For a description of the RTTM format, consult Appendix A of the NIST RT-09 81 | evaluation plan. 82 | 83 | Parameters 84 | ---------- 85 | rttmf : str 86 | Path to output RTTM file. 87 | 88 | turns : list of Turn 89 | Speaker turns. 90 | 91 | n_digits : int, optional 92 | Number of decimal digits to round to. 93 | (Default: 3) 94 | 95 | channel : int, optional 96 | Output channel 97 | (Default: 1) 98 | 99 | References 100 | ---------- 101 | NIST. (2009). The 2009 (RT-09) Rich Transcription Meeting Recognition 102 | Evaluation Plan. https://web.archive.org/web/20100606041157if_/http://www.itl.nist.gov/iad/mig/tests/rt/2009/docs/rt09-meeting-eval-plan-v2.pdf 103 | """ 104 | with open(rttmf, "wb") as f: 105 | for turn in turns: 106 | fields = [ 107 | "SPEAKER", 108 | turn.file_id, 109 | str(channel), 110 | format_float(turn.onset, n_digits), 111 | format_float(turn.dur, n_digits), 112 | "", 113 | "", 114 | str(turn.speaker_id), 115 | "", 116 | "", 117 | ] 118 | line = " ".join(fields) 119 | f.write(line.encode("utf-8")) 120 | f.write(b"\n") 121 | 122 | 123 | def validate_rttm(rttmf): 124 | """Validate RTTM file. 125 | 126 | Parameters 127 | ---------- 128 | rttmf : str 129 | Path to RTTM file. 130 | 131 | Returns 132 | ------- 133 | file_ids : set of str 134 | File ids present in ``rttmf``. 135 | 136 | speaker_ids : set of str 137 | Speaker ids present in ``rttm``. 
138 | 139 | error_messages : list of str 140 | Errors encountered in file. 141 | """ 142 | with open(rttmf, "rb") as f: 143 | file_ids = set() 144 | speaker_ids = set() 145 | error_messages = [] 146 | for line in f: 147 | if line.startswith(b"SPKR-INFO"): 148 | continue 149 | try: 150 | turn = _parse_rttm_line(line) 151 | file_ids.add(turn.file_id) 152 | speaker_ids.add(turn.speaker_id) 153 | except IOError as e: 154 | error_messages.append(e.args[0]) 155 | return file_ids, speaker_ids, error_messages 156 | -------------------------------------------------------------------------------- /dover_lap/libs/turn.py: -------------------------------------------------------------------------------- 1 | """Classes for representing speaker turns and interacting with RTTM files. 2 | 3 | Taken from https://github.com/nryant/dscore 4 | """ 5 | from collections import namedtuple 6 | from intervaltree import Interval, IntervalTree 7 | 8 | from .uem import UEM 9 | from .utils import groupby, warn, xor 10 | 11 | 12 | Token = namedtuple("Token", ["type", "timestamp", "speaker", "hyp_index", "weight"]) 13 | 14 | 15 | class Turn(object): 16 | """Speaker turn class. 17 | 18 | A turn represents a segment of audio attributed to a single speaker. 19 | 20 | Parameters 21 | ---------- 22 | onset : float 23 | Onset of turn in seconds from beginning of recording. 24 | 25 | offset : float, optional 26 | Offset of turn in seconds from beginning of recording. If None, then 27 | computed from ``onset`` and ``dur``. 28 | (Default: None) 29 | 30 | dur : float, optional 31 | Duration of turn in seconds. If None, then computed from ``onset`` and 32 | ``offset``. 33 | (Default: None) 34 | 35 | speaker_id : str, optional 36 | Speaker id. 37 | (Default: None) 38 | 39 | file_id : str, optional 40 | File id. 
41 | (Default: none) 42 | """ 43 | 44 | def __init__( 45 | self, onset, offset=None, dur=None, speaker_id=None, file_id=None, weight=None 46 | ): 47 | if not xor(offset is None, dur is None): 48 | raise ValueError("Exactly one of offset or dur must be given") 49 | if onset < 0: 50 | raise ValueError("Turn onset must be >= 0 seconds") 51 | if offset: 52 | dur = offset - onset 53 | if dur <= 0: 54 | raise ValueError("Turn duration must be > 0 seconds") 55 | if not offset: 56 | offset = onset + dur 57 | self.onset = onset 58 | self.offset = offset 59 | self.dur = dur 60 | self.speaker_id = speaker_id 61 | self.file_id = file_id 62 | self.weight = weight 63 | 64 | def __eq__(self, other): 65 | if isinstance(other, self.__class__): 66 | return self.__dict__ == other.__dict__ 67 | return False 68 | 69 | def __ne__(self, other): 70 | return not self.__eq__(other) 71 | 72 | def __hash__(self): 73 | return hash((self.onset, self.offset, self.dur, self.file_id, self.speaker_id)) 74 | 75 | def __str__(self): 76 | return "FILE: %s, SPEAKER: %s, ONSET: %f, OFFSET: %f, DUR: %f" % ( 77 | self.file_id, 78 | self.speaker_id, 79 | self.onset, 80 | self.offset, 81 | self.dur, 82 | ) 83 | 84 | def __repr__(self): 85 | speaker_id = "'%s'" % self.speaker_id if self.speaker_id is not None else None 86 | file_id = "'%s'" % self.file_id if self.file_id is not None else None 87 | return "Turn(%f, %f, None, %s, %s)" % ( 88 | self.onset, 89 | self.offset, 90 | speaker_id, 91 | file_id, 92 | ) 93 | 94 | 95 | def merge_turns(turns): 96 | """Merge overlapping turns by same speaker within each file.""" 97 | # Merge separately within each file and for each speaker. 
98 | new_turns = [] 99 | for (file_id, speaker_id), speaker_turns in groupby( 100 | turns, lambda x: (x.file_id, x.speaker_id) 101 | ): 102 | speaker_turns = list(speaker_turns) 103 | speaker_it = IntervalTree.from_tuples( 104 | [(turn.onset, turn.offset) for turn in speaker_turns] 105 | ) 106 | n_turns_pre = len(speaker_it) 107 | speaker_it.merge_overlaps(strict=False) 108 | n_turns_post = len(speaker_it) 109 | if n_turns_post < n_turns_pre: 110 | speaker_turns = [] 111 | for intrvl in speaker_it: 112 | speaker_turns.append( 113 | Turn( 114 | intrvl.begin, intrvl.end, speaker_id=speaker_id, file_id=file_id 115 | ) 116 | ) 117 | speaker_turns = sorted(speaker_turns, key=lambda x: (x.onset, x.offset)) 118 | # warn('Merging overlapping speaker turns. ' 119 | # 'FILE: %s, SPEAKER: %s' % (file_id, speaker_id)) 120 | new_turns.extend(speaker_turns) 121 | return new_turns 122 | 123 | 124 | def chop_tree(tree, onset, offset): 125 | """Trim Intervals so that none overlap [``onset``, ``offset``]. 126 | 127 | Intervals contained entirely within the chopped region are removed and 128 | those overlapping, but not contained are trimmed back. Differs from 129 | ``IntervalTree.chop`` in that it keeps track of which intervals in the 130 | tree were affected. 131 | 132 | This is an inplace operation. 133 | 134 | Parameters 135 | ---------- 136 | tree : IntervalTree 137 | Interval tree. 138 | 139 | onset : float 140 | Onset of chopped region. 141 | 142 | offset : float 143 | Offset of chopped region. 144 | 145 | Returns 146 | ------- 147 | affected_intervals : set of Interval 148 | Intervals from ``tree`` that overlap chopped region. 149 | """ 150 | overlapped_intervals = set() # Intervals overlapping chopped region. 151 | insertions = set() # Intervals to add. 152 | 153 | # Identify intervals contained entirely within [onset, offset]. 154 | overlapped_intervals.update(tree.envelop(onset, offset)) 155 | 156 | # Identify all other intervals overlapping [onset, offset]. 
These belong 157 | # to two classes: 158 | # - right overlap -- interval.begin < onset 159 | # - left overlap -- offset < interval.end 160 | for intrvl in tree.at(onset): 161 | if intrvl.begin >= onset: 162 | continue 163 | overlap_dur = intrvl.end - onset 164 | if not overlap_dur: 165 | continue 166 | overlapped_intervals.add(intrvl) 167 | insertions.add(Interval(intrvl.begin, onset, intrvl.data)) 168 | for intrvl in tree.at(offset): 169 | if intrvl.end <= offset: 170 | continue 171 | overlap_dur = offset - intrvl.begin 172 | if not overlap_dur: 173 | continue 174 | overlapped_intervals.add(intrvl) 175 | insertions.add(Interval(offset, intrvl.end, intrvl.data)) 176 | 177 | # Update tree. 178 | for intrvl in overlapped_intervals: 179 | tree.discard(intrvl) 180 | tree.update(insertions) 181 | 182 | return overlapped_intervals 183 | 184 | 185 | MAX_SESSION_DUR = 1e6 # Maximum duration (seconds) of session. Any outlandishly 186 | # high number will do. 187 | 188 | 189 | def trim_turns(turns, uem=None, score_onset=None, score_offset=None): 190 | """Trim turns to scoring regions defined in UEM. 191 | 192 | Parameters 193 | ---------- 194 | turns : list of Turn 195 | Speaker turns. 196 | 197 | uem : UEM, optional 198 | Un-partitioned evaluation map. 199 | (Default: None) 200 | 201 | score_onset : float, optional 202 | Onset of scoring region in seconds from beginning of file. Only valid 203 | if ``uem=None``. 204 | (Default: None) 205 | 206 | score_offset : float, optional 207 | Offset of scoring region in seconds from beginning of file. Only 208 | valid if ``uem=None``. 209 | (Default: None) 210 | 211 | Returns 212 | ------- 213 | trimmed_turns : list of Turn 214 | Trimmed turns. 215 | """ 216 | # Validate arguments. 217 | if uem is not None: 218 | if not (score_onset is None and score_offset is None): 219 | raise ValueError( 220 | "Either uem or score_onset and score_offset must " "be specified." 
221 | ) 222 | else: 223 | if score_onset is None or score_offset is None: 224 | raise ValueError( 225 | "Either uem or score_onset and score_offset must " "be specified." 226 | ) 227 | if score_onset < 0: 228 | raise ValueError("Scoring region onset must be >= 0 seconds") 229 | if score_offset <= score_onset: 230 | raise ValueError("Scoring region duration must be > 0 seconds") 231 | 232 | # If no UEM provided, set each file to have same scoring region: 233 | # (score_onset, score_offset). 234 | if uem is None: 235 | file_ids = set([turn.file_id for turn in turns]) 236 | uem = UEM({fid: [(score_onset, score_offset)] for fid in file_ids}) 237 | 238 | # Trim turns to scoring regions. 239 | new_turns = [] 240 | for file_id, file_turns in groupby(turns, lambda x: x.file_id): 241 | if file_id not in uem: 242 | for turn in file_turns: 243 | warn("Skipping turn from file not in UEM. TURN: %s" % turn) 244 | continue 245 | 246 | # Remove overlaps with no score regions. 247 | noscore_tree = IntervalTree.from_tuples([(0.0, MAX_SESSION_DUR)]) 248 | for score_onset, score_offset in uem[file_id]: 249 | noscore_tree.chop(score_onset, score_offset) 250 | turns_tree = IntervalTree.from_tuples( 251 | (turn.onset, turn.offset, turn) for turn in file_turns 252 | ) 253 | overlapped_turns = set() # Turns found to overlap a no score region. 254 | for noscore_intrvl in noscore_tree: 255 | overlapped_intrvls = chop_tree( 256 | turns_tree, noscore_intrvl.begin, noscore_intrvl.end 257 | ) 258 | overlapped_turns.update([intrvl.data for intrvl in overlapped_intrvls]) 259 | 260 | # Convert interval tree to turns. 261 | for intrvl in turns_tree: 262 | orig_turn = intrvl.data 263 | new_turns.append( 264 | Turn( 265 | intrvl.begin, 266 | intrvl.end, 267 | speaker_id=orig_turn.speaker_id, 268 | file_id=orig_turn.file_id, 269 | ) 270 | ) 271 | 272 | # Report any overlapping turns to STDERR. 
273 | for turn in sorted(overlapped_turns, key=lambda x: (x.onset, x.offset)): 274 | warn("Truncating turn overlapping non-scoring region. TURN: %s" % turn) 275 | 276 | return new_turns 277 | -------------------------------------------------------------------------------- /dover_lap/libs/uem.py: -------------------------------------------------------------------------------- 1 | """Functions for reading/writing and manipulating NIST un-partitioned 2 | evaluation maps. 3 | 4 | An un-partitioned evaluation map (UEM) specifies the time regions within each 5 | file that will be scored. 6 | 7 | Taken from https://github.com/nryant/dscore 8 | """ 9 | from collections import defaultdict 10 | from collections.abc import MutableMapping 11 | 12 | import os 13 | 14 | from intervaltree import IntervalTree 15 | 16 | 17 | class UEM(MutableMapping): 18 | """Un-partitioned evaluation map (UEM). 19 | 20 | A UEM defines a mapping from file ids to scoring regions. 21 | """ 22 | 23 | def __init__(self, *args, **kwargs): 24 | super(UEM, self).__init__() 25 | self.update(*args, **kwargs) 26 | 27 | def __setitem__(self, fid, score_regions): 28 | # Validate types. Expects sequence of (onset, offset) pairs. 29 | invalid_type_msg = "Expected sequence of pairs. Received: %r (%s)." % ( 30 | score_regions, 31 | type(score_regions), 32 | ) 33 | try: 34 | score_regions = [tuple(region) for region in score_regions] 35 | except TypeError: 36 | raise TypeError(invalid_type_msg) 37 | for score_region in score_regions: 38 | if len(score_region) != 2: 39 | raise TypeError(invalid_type_msg) 40 | 41 | # Validate that the (onset, offset) pairs are valid: no negative 42 | # timestamps or negative durations. 
43 | def _convert_to_float(score_region): 44 | onset, offset = score_region 45 | try: 46 | onset = float(onset) 47 | offset = float(offset) 48 | except ValueError: 49 | raise ValueError( 50 | "Could not convert interval onset/offset to float: %s" 51 | % repr(score_region) 52 | ) 53 | if onset >= offset or onset < 0: 54 | raise ValueError( 55 | 'Invalid interval (%.3f, %.3f) for file "%s".' 56 | % (onset, offset, fid) 57 | ) 58 | return onset, offset 59 | 60 | score_regions = [_convert_to_float(region) for region in score_regions] 61 | 62 | # Merge overlaps. Use of intervaltree incurs some additional overhead, 63 | # but it is pretty small compared to actual scoring. 64 | tree = IntervalTree.from_tuples(score_regions) 65 | tree.merge_overlaps() 66 | score_regions = [(intrvl.begin, intrvl.end) for intrvl in tree] 67 | 68 | self.__dict__[fid] = score_regions 69 | 70 | def __getitem__(self, key): 71 | return self.__dict__[key] 72 | 73 | def __delitem__(self, key): 74 | del self.__dict__[key] 75 | 76 | def __iter__(self): 77 | return iter(self.__dict__) 78 | 79 | def __len__(self): 80 | return len(self.__dict__) 81 | 82 | def __str__(self): 83 | return str(self.__dict__) 84 | 85 | def __repr__(self): 86 | return "{}, UEM({})".format(super(UEM, self).__repr__(), self.__dict__) 87 | 88 | 89 | def load_uem(uemf): 90 | """Load un-partitioned evaluation map from file in NIST format. 91 | 92 | The un-partitioned evaluation map (UEM) file format contains 93 | one record per line, each line consisting of four space-delimited 94 | fields: 95 | 96 | - file id -- file id 97 | - channel -- channel (1-indexed) 98 | - onset -- onset of evaluation region in seconds from beginning of file 99 | - offset -- offset of evaluation region in seconds from beginning of 100 | file 101 | 102 | Lines beginning with semicolons are regarded as comments and ignored. 103 | 104 | Parameters 105 | ---------- 106 | uemf : str 107 | Path to UEM file. 
108 | 109 | Returns 110 | ------- 111 | uem : UEM 112 | Evaluation map. 113 | """ 114 | with open(uemf, "rb") as f: 115 | fid_to_score_regions = defaultdict(list) 116 | for line in f: 117 | if line.startswith(b";"): 118 | continue 119 | fields = line.decode("utf-8").strip().split() 120 | file_id = os.path.splitext(fields[0])[0] 121 | onset = float(fields[2]) 122 | offset = float(fields[3]) 123 | fid_to_score_regions[file_id].append((onset, offset)) 124 | return UEM(fid_to_score_regions.items()) 125 | -------------------------------------------------------------------------------- /dover_lap/libs/utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions. 2 | 3 | Taken from https://github.com/nryant/dscore 4 | """ 5 | import itertools 6 | import sys 7 | import click 8 | import ast 9 | 10 | 11 | def error(msg, file=sys.stderr): 12 | """Log error message ``msg`` to stderr.""" 13 | msg = "ERROR: %s" % msg 14 | print(msg, file=file) 15 | 16 | 17 | def info(msg, print_level=False, file=sys.stdout): 18 | """Log info message ``msg`` to stdout.""" 19 | if print_level: 20 | msg = "INFO: %s" % msg 21 | print(msg, file=file) 22 | 23 | 24 | def warn(msg, file=sys.stderr): 25 | """Log warning message ``msg`` to stderr.""" 26 | msg = "WARNING: %s" % msg 27 | print(msg, file=file) 28 | 29 | 30 | def xor(x, y): 31 | """Return truth value of ``x`` XOR ``y``.""" 32 | return bool(x) != bool(y) 33 | 34 | 35 | def format_float(x, n_digits=3): 36 | """Format floating point number for output as string. 37 | 38 | Parameters 39 | ---------- 40 | x : float 41 | Number. 42 | 43 | n_digits : int, optional 44 | Number of decimal digits to round to. 45 | (Default: 3) 46 | 47 | Returns 48 | ------- 49 | s : str 50 | Formatted string. 
51 | """ 52 | fmt_str = "%%.%df" % n_digits 53 | return fmt_str % round(x, n_digits) 54 | 55 | 56 | def clip(x, lower, upper): 57 | """Clip ``x`` to [``lower``, ``upper``].""" 58 | return min(max(x, lower), upper) 59 | 60 | 61 | def groupby(iterable, keyfunc): 62 | """Wrapper around ``itertools.groupby`` which sorts data first.""" 63 | iterable = sorted(iterable, key=keyfunc) 64 | for key, group in itertools.groupby(iterable, keyfunc): 65 | yield key, group 66 | 67 | 68 | # If an option is selected, other options become required 69 | def command_required_option(require_name, require_map): 70 | class CommandOptionRequiredClass(click.Command): 71 | def invoke(self, ctx): 72 | require = ctx.params[require_name] 73 | if require not in require_map: 74 | raise click.ClickException( 75 | "Unexpected value for --'{}': {}".format(require_name, require) 76 | ) 77 | if ( 78 | require_map[require] is not None 79 | and ctx.params[require_map[require].lower()] is None 80 | ): 81 | raise click.ClickException( 82 | "With {}={} must specify option --{}".format( 83 | require_name, require, require_map[require] 84 | ) 85 | ) 86 | super(CommandOptionRequiredClass, self).invoke(ctx) 87 | 88 | return CommandOptionRequiredClass 89 | 90 | 91 | # Class to accept list of arguments as Click option 92 | class PythonLiteralOption(click.Option): 93 | def type_cast_value(self, ctx, value): 94 | try: 95 | return ast.literal_eval(value) 96 | except: 97 | return None 98 | -------------------------------------------------------------------------------- /dover_lap/src/__init__.py: -------------------------------------------------------------------------------- 1 | from .doverlap import DOVERLap 2 | from .label_mapping import LabelMapping 3 | from .label_voting import LabelVoting 4 | from .mapping import * 5 | -------------------------------------------------------------------------------- /dover_lap/src/doverlap.py: -------------------------------------------------------------------------------- 1 | 
import numpy as np
2 | 
3 | from typing import List, Optional
4 | 
5 | from dover_lap.libs.turn import Turn, merge_turns
6 | from dover_lap.src.label_mapping import LabelMapping
7 | from dover_lap.src.label_voting import LabelVoting
8 | 
9 | 
10 | class DOVERLap:
11 |     @classmethod
12 |     def combine_turns_list(
13 |         cls,
14 |         turns_list: List[List[Turn]],
15 |         file_id: str,
16 |         label_mapping: Optional[str] = "greedy",
17 |         second_maximal: Optional[bool] = False,
18 |         voting_method: Optional[str] = "average",
19 |         weight_type: Optional[str] = "rank",
20 |         dover_weight: Optional[float] = 0.1,
21 |         custom_weight: Optional[List[str]] = None,
22 |         gaussian_filter_std: float = 0.01,
23 |     ) -> List[Turn]:
24 | 
25 |         # Label mapping stage
26 |         mapped_turns_list, weights = LabelMapping.get_mapped_turns_list(
27 |             turns_list,
28 |             file_id,
29 |             method=label_mapping,
30 |             second_maximal=second_maximal,
31 |         )
32 | 
33 |         # Compute weights based on rank
34 |         if weight_type == "rank":
35 |             ranks = cls.__get_ranks(weights)
36 |             weights = cls.__compute_weights(ranks, dover_weight)
37 |         elif weight_type == "custom":
38 |             assert isinstance(custom_weight, list)
39 |             weights = np.array([float(x) for x in custom_weight])
40 |         elif weight_type == "norm":
41 |             weights /= np.linalg.norm(weights, ord=1)  # use normalized weights
42 | 
43 |         # Label voting stage
44 |         combined_turns_list = LabelVoting.get_combined_turns(
45 |             mapped_turns_list,
46 |             file_id,
47 |             voting_method,
48 |             weights,
49 |             gaussian_filter_std=gaussian_filter_std,
50 |         )
51 |         # Merge consecutive turns with the same label
52 |         combined_turns_list = merge_turns(combined_turns_list)
53 |         return combined_turns_list
54 | 
55 |     def __get_ranks(weights: np.ndarray) -> np.ndarray:
56 | 
57 |         weights /= np.linalg.norm(weights, ord=1)
58 |         temp = weights.argsort()
59 |         ranks = np.empty_like(temp)
60 |         ranks[temp] = np.arange(len(weights)) + 1
61 |         return ranks
62 | 
63 |     def __compute_weights(ranks: np.ndarray, weight: float) -> np.ndarray:
64 | 
65 |         out_weights = 1 / np.power(ranks, weight)
66 |         out_weights /= np.linalg.norm(out_weights, ord=1)
67 |         return out_weights
68 | 
--------------------------------------------------------------------------------
/dover_lap/src/label_mapping.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict, Optional, Tuple
2 | import numpy as np
3 | from dover_lap.libs.utils import groupby
4 | from dover_lap.libs.turn import Turn
5 | 
6 | from dover_lap.src.mapping import HungarianMap, GreedyMap
7 | 
8 | 
9 | class LabelMapping:
10 |     @classmethod
11 |     def get_mapped_turns_list(
12 |         cls,
13 |         turns_list: List[List[Turn]],
14 |         file_id: str,
15 |         method: Optional[str] = "greedy",
16 |         second_maximal: Optional[bool] = False,
17 |     ) -> Tuple[List[List[Turn]], np.ndarray]:
18 |         """
19 |         This function takes turns lists from all RTTMs and applies an n-dimensional
20 |         matching approximation algorithm to map them all to a common label space.
21 |         """
22 | 
23 |         if (len(turns_list) == 2) or (method == "hungarian"):
24 |             # We replace the original turns list with one sorted by average DER
25 |             hungarian_map = HungarianMap()
26 |             label_mapping, weights = hungarian_map.compute_mapping(turns_list)
27 |             turns_list = hungarian_map.sorted_turns_list
28 | 
29 |         elif method == "greedy":
30 |             greedy_map = GreedyMap(second_maximal=second_maximal)
31 |             label_mapping, weights = greedy_map.compute_mapping(turns_list)
32 | 
33 |         # Get mapped speaker labels using the mapping
34 |         mapped_turns_list = []
35 |         for i, turns in enumerate(turns_list):
36 |             spk_groups = {
37 |                 key: list(group)
38 |                 for key, group in groupby(turns, lambda x: x.speaker_id)
39 |             }
40 |             mapped_turns = []
41 |             for spk_id in spk_groups.keys():
42 |                 new_spk_id = label_mapping[(i, spk_id)]
43 |                 for turn in spk_groups[spk_id]:
44 |                     mapped_turns.append(
45 |                         Turn(
46 |                             turn.onset,
47 |                             turn.offset,
48 |                             speaker_id=new_spk_id,
49 |                             file_id=file_id,
50 |                         )
51 |                     )
52 |             mapped_turns_list.append(mapped_turns)
53 | 
54
| return mapped_turns_list, weights
55 | 
--------------------------------------------------------------------------------
/dover_lap/src/label_voting.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import numpy as np
3 | 
4 | from collections import defaultdict
5 | from typing import List, Optional, Tuple
6 | 
7 | from dover_lap.libs.turn import Token, Turn
8 | from dover_lap.src.voting import WeightedAverageVoting
9 | 
10 | 
11 | class LabelVoting:
12 |     EPS = 1e-3  # to avoid float equality check errors
13 | 
14 |     @classmethod
15 |     def get_combined_turns(
16 |         cls,
17 |         turns_list: List[List[Turn]],
18 |         file_id: str,
19 |         voting_method: Optional[str] = "average",
20 |         weights: Optional[np.ndarray] = None,
21 |         gaussian_filter_std: float = 0.01,
22 |     ) -> List[Turn]:
23 |         """
24 |         Implements combination using the DOVER-Lap label voting method.
25 |         :param turns_list, list of mapped speaker turns (from each hypothesis)
26 |         :param file_id, name of the file (recording/session)
27 |         :param voting_method, which method to use for combining labels. Options:
28 |             - `average`: use weighted average of labels (currently the only
29 |               implemented method; any other value raises a ValueError)
30 |         :param weights, hypothesis weights to use for rank weighting
31 |         :param gaussian_filter_std, standard deviation of Gaussian filter to apply
32 |         :return: combined turns
33 |         """
34 |         regions, start_end = cls.__get_regions(turns_list, weights)
35 | 
36 |         if voting_method == "average":
37 |             voter = WeightedAverageVoting(gaussian_filter_std=gaussian_filter_std)
38 |         else:
39 |             raise ValueError("Unknown voting method: {}".format(voting_method))
40 | 
41 |         combined_turns_list = voter.get_combined_turns(regions, start_end, file_id)
42 | 
43 |         return combined_turns_list
44 | 
45 |     @classmethod
46 |     def __get_regions(
47 |         cls, turns_list: List[List[Turn]], weights: Optional[np.ndarray] = None
48 |     ) -> Tuple[np.ndarray, np.ndarray]:
49 |         """
50 |         Returns homogeneous time regions of the input.
51 |         """
52 |         # Map speaker ids to consecutive integers (0, 1, 2, ...)
53 |         spk_index = {}
54 |         for turns in turns_list:
55 |             for turn in turns:
56 |                 if turn.speaker_id not in spk_index:
57 |                     spk_index[turn.speaker_id] = len(spk_index)
58 | 
59 |         # Add weights to turns, and update speaker id.
New speaker id is a tuple of (hyp_index, spk_index) 60 | if weights is None: 61 | weights = np.array([1] * len(turns_list)) 62 | weighted_turns_list = [] 63 | 64 | for i, (turns, weight) in enumerate(zip(turns_list, weights)): 65 | weighted_turns = [ 66 | Turn( 67 | turn.onset, 68 | offset=turn.offset, 69 | speaker_id=(i, spk_index[turn.speaker_id]), 70 | file_id=turn.file_id, 71 | weight=weight, 72 | ) 73 | for turn in turns 74 | ] 75 | weighted_turns_list.append(weighted_turns) 76 | all_turns = [turn for turns in weighted_turns_list for turn in turns] 77 | 78 | tokens = [] 79 | for turn in all_turns: 80 | # Name is 'START' (not 'BEG') so that 'END' tokens come first for same timestamp 81 | tokens.append( 82 | Token( 83 | type="START", 84 | timestamp=turn.onset, 85 | hyp_index=turn.speaker_id[0], 86 | speaker=turn.speaker_id[1], 87 | weight=turn.weight, 88 | ) 89 | ) 90 | tokens.append( 91 | Token( 92 | type="END", 93 | timestamp=turn.offset, 94 | hyp_index=turn.speaker_id[0], 95 | speaker=turn.speaker_id[1], 96 | weight=turn.weight, 97 | ) 98 | ) 99 | 100 | sorted_tokens = sorted(tokens, key=lambda x: (x.timestamp, x.type)) 101 | 102 | regions_list = [] 103 | region_start = sorted_tokens[0].timestamp 104 | # We also maintain a running dictionary of speakers and their weights contributed 105 | # by each hypothesis. These weights are stored as a list indexed by hyp_index. 
106 | running_speakers_dict = defaultdict( 107 | lambda: [0.0 for _ in range(len(turns_list))] 108 | ) 109 | running_speakers_dict[sorted_tokens[0].speaker][ 110 | sorted_tokens[0].hyp_index 111 | ] += sorted_tokens[0].weight 112 | 113 | for token in sorted_tokens[1:]: 114 | if token.timestamp - region_start > cls.EPS: 115 | running_speakers = [] 116 | for k, v in running_speakers_dict.items(): 117 | if sum(v) > cls.EPS: 118 | running_speakers.append((k, v)) 119 | if len(running_speakers) > 0: 120 | regions_list.append( 121 | (region_start, token.timestamp, copy.deepcopy(running_speakers)) 122 | ) 123 | # Update the weights list for the current speaker 124 | weights = running_speakers_dict[token.speaker] 125 | if token.type == "START": 126 | weights[token.hyp_index] += token.weight 127 | else: 128 | weights[token.hyp_index] -= token.weight 129 | running_speakers_dict[token.speaker] = weights 130 | 131 | region_start = token.timestamp 132 | 133 | # Build regions matrix and start_end matrix. Regions matrix is of shape 134 | # (num_regions, num_speakers, num_hypotheses). start_end matrix is of shape 135 | # (num_regions, 2) and contains the start and end times of each region. 
136 | regions = np.zeros( 137 | (len(regions_list), len(spk_index), len(turns_list)), dtype=np.float32 138 | ) 139 | start_end = np.zeros((len(regions_list), 2), dtype=np.float32) 140 | for i, region in enumerate(regions_list): 141 | start_end[i, 0] = region[0] 142 | start_end[i, 1] = region[1] 143 | for spk, weights in region[2]: 144 | regions[i, spk] = weights 145 | 146 | return regions, start_end 147 | -------------------------------------------------------------------------------- /dover_lap/src/mapping/__init__.py: -------------------------------------------------------------------------------- 1 | from .hungarian import HungarianMap 2 | from .greedy import GreedyMap 3 | -------------------------------------------------------------------------------- /dover_lap/src/mapping/greedy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | 4 | from typing import List, Dict, Tuple, Optional 5 | 6 | from dover_lap.libs.utils import groupby, info 7 | from dover_lap.libs.turn import Turn 8 | from .map_utils import * 9 | 10 | 11 | class GreedyMap: 12 | def __init__(self, second_maximal: Optional[bool] = False) -> None: 13 | self.second_maximal = second_maximal 14 | 15 | def compute_mapping( 16 | self, 17 | turns_list: List[List[Turn]], 18 | ) -> Tuple[Dict[Tuple[int, str], int], np.ndarray]: 19 | """ 20 | Use the DOVER-Lap greedy label mapping algorithm. Returns a map from 21 | old labels to new, and the weights for the hypotheses. 
22 | """ 23 | self.turns_list = turns_list 24 | N = len(self.turns_list) 25 | cost_tensor, pairwise_costs = self.compute_cost_tensor(turns_list) 26 | 27 | # The weight of each hypothesis is computed by computing its total 28 | # overlap with all other hypotheses 29 | weights = np.array([0] * N, dtype=float) 30 | for i in range(N): 31 | cur_pairwise_costs = [ 32 | np.squeeze(x) for x in pairwise_costs.values() if x.shape[i] != 1 33 | ] 34 | weights[i] = -1 * sum([np.sum(x) for x in cur_pairwise_costs]) 35 | 36 | label_mapping = self._apply_maximal_matching( 37 | cost_tensor, 38 | get_speaker_keys(turns_list), 39 | ) 40 | return (label_mapping, weights) 41 | 42 | def compute_cost_tensor(self, turns_list: List[List[Turn]]) -> np.ndarray: 43 | 44 | N = len(turns_list) 45 | k = int((N * (N - 1) / 2)) 46 | pairwise_costs = {} 47 | 48 | has_single_speaker = False 49 | 50 | for i, ref_turns in enumerate(turns_list): 51 | for j, sys_turns in enumerate(turns_list): 52 | if j <= i: 53 | continue 54 | cost = [] 55 | ref_groups = { 56 | key: list(group) 57 | for key, group in groupby(ref_turns, lambda x: x.speaker_id) 58 | } 59 | sys_groups = { 60 | key: list(group) 61 | for key, group in groupby(sys_turns, lambda x: x.speaker_id) 62 | } 63 | 64 | if len(ref_groups.keys()) == 1 or len(sys_groups.keys()) == 1: 65 | has_single_speaker = True 66 | 67 | for ref_spk_id in sorted(ref_groups.keys()): 68 | cur_row = [] 69 | ref_spk_turns = ref_groups[ref_spk_id] 70 | for sys_spk_id in sorted(sys_groups.keys()): 71 | sys_spk_turns = sys_groups[sys_spk_id] 72 | total_overlap = compute_spk_overlap( 73 | ref_spk_turns, sys_spk_turns 74 | ) 75 | cur_row.append(-1 * total_overlap) 76 | cost.append(cur_row) 77 | 78 | new_axis = list(range(N)) 79 | new_axis.remove(i) 80 | new_axis.remove(j) 81 | 82 | # The expand_dims is for easy broadcasting 83 | pairwise_costs[(i, j)] = np.expand_dims( 84 | np.array(cost), axis=tuple(k for k in new_axis) 85 | ) 86 | 87 | if has_single_speaker: 88 | # iterate 
and add since numpy cannot broadcast with 2 dummy dimensions
89 |             vals = list(pairwise_costs.values())
90 |             cost_tensor = vals[0]
91 |             for val in vals[1:]:
92 |                 cost_tensor = np.add(cost_tensor, val)
93 |         else:
94 |             # otherwise use broadcasting
95 |             cost_tensor = np.sum(np.fromiter(pairwise_costs.values(), dtype=object))
96 |         return cost_tensor, pairwise_costs
97 | 
98 |     def _apply_maximal_matching(
99 |         self,
100 |         cost_tensor: np.ndarray,
101 |         speakers_dict: Dict[Tuple[int, int], str],
102 |     ) -> Dict[Tuple[int, str], int]:
103 | 
104 |         # Sort the cost tensor
105 |         sorted_idx = np.transpose(
106 |             np.unravel_index(np.argsort(cost_tensor, axis=None), cost_tensor.shape)
107 |         )
108 | 
109 |         # Get the maximal matching using an approximation algorithm
110 |         M = []
111 |         remaining_idx = {
112 |             i: list(range(cost_tensor.shape[i])) for i in range(len(cost_tensor.shape))
113 |         }
114 | 
115 |         iteration = 1
116 |         while len(remaining_idx.keys()) != 0:
117 |             info(
118 |                 "Iteration {}: {} labels left to be mapped".format(
119 |                     iteration,
120 |                     sum([len(v) for v in remaining_idx.values()]),
121 |                 ),
122 |                 file=sys.stderr,
123 |             )
124 |             sorted_idx_filtered = self._filter_sorted_index_list(
125 |                 sorted_idx, remaining_idx
126 |             )
127 | 
128 |             # find initial maximal matching
129 |             M_cur = []
130 |             for idx in sorted_idx_filtered:
131 |                 if not self._contradicts(M_cur, idx):
132 |                     M_cur.append(idx)
133 | 
134 |             if self.second_maximal:
135 |                 # find second maximal matching
136 |                 change = True
137 |                 while change:
138 |                     change = False
139 |                     for idx in list(M_cur):
140 |                         M_cur.remove(idx)
141 |                         M_r = self._find_remaining_maximal_matching(
142 |                             M_cur, sorted_idx_filtered
143 |                         )
144 |                         if len(M_r) > 1:
145 |                             M_cur = M_cur + M_r
146 |                             change = True
147 |                         else:
148 |                             M_cur.append(idx)
149 | 
150 |             for idx in M_cur:
151 |                 for i, j in enumerate(idx):
152 |                     if i in remaining_idx and j in remaining_idx[i]:
153 |                         remaining_idx[i].remove(j)
154 | 
155 |             for i in list(remaining_idx.keys()):
156 |                 if len(remaining_idx[i]) == 0:
157 |                     del remaining_idx[i]
158 | 
159 |             M += M_cur
160 |             iteration += 1
161 | 
162 |         label_mapping = {}
163 |         for k, idx_tuple in enumerate(M):
164 |             # For each speaker j in hypothesis i, assign new label k
165 |             for i, j in enumerate(idx_tuple):
166 |                 old_spk_id = speakers_dict[(i, j)]
167 |                 if (i, old_spk_id) not in label_mapping:
168 |                     label_mapping[(i, old_spk_id)] = k
169 | 
170 |         return label_mapping
171 | 
172 |     def _find_remaining_maximal_matching(
173 |         self, M: List[Tuple[int, ...]], idx_list: List[Tuple[int, ...]]
174 |     ) -> List[Tuple[int, ...]]:
175 |         """
176 |         Given a list of index tuples and a matching M, find a maximal
177 |         matching on the "remaining" list, i.e., using only those index
178 |         tuples which are not present in the original matching.
179 |         """
180 |         S_r = []
181 |         for idx in list(idx_list):
182 |             if not self._contradicts(M, idx):
183 |                 S_r.append(idx)
184 | 
185 |         M_r = []
186 |         for idx in S_r:
187 |             if not self._contradicts(M_r, idx):
188 |                 M_r.append(idx)
189 | 
190 |         return M_r
191 | 
192 |     def _filter_sorted_index_list(
193 |         self, sorted_idx: List[np.ndarray], remaining_idx: Dict[int, List[int]]
194 |     ) -> List[np.ndarray]:
195 |         """
196 |         Filter the sorted list of index tuples to only retain tuples for which
197 |         at least one element is in the remaining list.
198 |         """
199 |         sorted_idx_filtered = []
200 |         for idx_tuple in sorted_idx:
201 |             remaining = False
202 |             for i, j in enumerate(idx_tuple):
203 |                 if i in remaining_idx and j in remaining_idx[i]:
204 |                     remaining = True
205 |                     break
206 |             if remaining:
207 |                 sorted_idx_filtered.append(idx_tuple)
208 |         return sorted_idx_filtered
209 | 
210 |     def _contradicts(self, M: List[Tuple[int, ...]], idx_tuple: List[int]) -> bool:
211 |         """
212 |         Check if an index tuple contradicts a matching, i.e. return True if
213 |         any index in the tuple is already present in the matching.
214 | """ 215 | for i, index in enumerate(idx_tuple): 216 | existing_idx = [idx[i] for idx in M] 217 | if index in existing_idx: 218 | return True 219 | return False 220 | -------------------------------------------------------------------------------- /dover_lap/src/mapping/hungarian.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from typing import List, Dict, Tuple, Optional 4 | 5 | from dover_lap.libs.utils import groupby 6 | from dover_lap.libs.turn import Turn, merge_turns 7 | from spyder import DER 8 | 9 | from .map_utils import * 10 | 11 | 12 | class HungarianMap: 13 | def __init__(self) -> None: 14 | pass 15 | 16 | def compute_mapping( 17 | self, 18 | turns_list: List[List[Turn]], 19 | ) -> Tuple[Dict[Tuple[int, str], int], np.ndarray]: 20 | """ 21 | Use Hungarian algorithm for label mapping for 2 system special case. 22 | """ 23 | self.turns_list = turns_list 24 | 25 | weights = self._compute_weights() 26 | # Sort the hypotheses by their weights 27 | sorted_idx = weights.argsort().tolist() 28 | self.sorted_turns_list = [self.turns_list[i] for i in sorted_idx] 29 | 30 | cur_turns = self.sorted_turns_list[0] 31 | self.global_mapping = dict() 32 | 33 | for i in range(1, len(self.sorted_turns_list)): 34 | next_turns = self.sorted_turns_list[i] 35 | local_mapping = self._map_pair(cur_turns, next_turns) 36 | cur_turns = self._merge_pair(cur_turns, next_turns, local_mapping) 37 | self._update_global_map(local_mapping) 38 | 39 | if not self._validate_global_mapping(): 40 | raise Exception("Some speakers have not been mapped") 41 | return self.global_mapping, weights 42 | 43 | def _validate_global_mapping(self) -> bool: 44 | for i, turns in enumerate(self.sorted_turns_list): 45 | groups = { 46 | key: list(group) 47 | for key, group in groupby(turns, lambda x: x.speaker_id) 48 | } 49 | for spk in groups: 50 | if (i, spk) not in self.global_mapping: 51 | return False 52 | return True 53 | 54 | def 
_compute_weights(self) -> np.ndarray: 55 | N = len(self.turns_list) 56 | DERs = np.zeros(N) 57 | for i in range(N): 58 | DER_i = [] 59 | hyp = [ 60 | (turn.speaker_id, turn.onset, turn.offset) 61 | for turn in self.turns_list[i] 62 | ] 63 | for j in range(N): 64 | if i == j: 65 | continue 66 | ref = [ 67 | (turn.speaker_id, turn.onset, turn.offset) 68 | for turn in self.turns_list[j] 69 | ] 70 | der = DER(ref, hyp).der 71 | DER_i.append(der) 72 | DERs[i] = np.mean(DER_i) 73 | return DERs 74 | 75 | def _map_pair( 76 | self, ref_turns: List[Turn], sys_turns: List[Turn] 77 | ) -> Dict[Tuple[int, str], int]: 78 | ref = [(turn.speaker_id, turn.onset, turn.offset) for turn in ref_turns] 79 | sys = [(turn.speaker_id, turn.onset, turn.offset) for turn in sys_turns] 80 | metrics = DER(ref, sys) 81 | ref_map = metrics.ref_map 82 | sys_map = metrics.hyp_map 83 | 84 | label_mapping = {} 85 | for k, v in ref_map.items(): 86 | label_mapping[(0, k)] = v 87 | for k, v in sys_map.items(): 88 | label_mapping[(1, k)] = v 89 | 90 | return label_mapping 91 | 92 | def _merge_pair( 93 | self, 94 | ref_turns: List[Turn], 95 | sys_turns: List[Turn], 96 | label_map: Dict[Tuple[int, int], int], 97 | ) -> List[Turn]: 98 | ref_turns_mapped = [ 99 | Turn( 100 | onset=turn.onset, 101 | offset=turn.offset, 102 | speaker_id=label_map[(0, turn.speaker_id)], 103 | ) 104 | for turn in ref_turns 105 | ] 106 | sys_turns_mapped = [ 107 | Turn( 108 | onset=turn.onset, 109 | offset=turn.offset, 110 | speaker_id=label_map[(1, turn.speaker_id)], 111 | ) 112 | for turn in sys_turns 113 | ] 114 | all_turns = merge_turns(ref_turns_mapped + sys_turns_mapped) 115 | return all_turns 116 | 117 | def _update_global_map(self, local_map: Dict[Tuple[int, str], int]) -> None: 118 | if not self.global_mapping: 119 | self.global_mapping = local_map.copy() 120 | return 121 | new_global_map = {} 122 | max_file_id = 0 123 | for key, old_id in self.global_mapping.items(): 124 | file_id, spk_id = key 125 | max_file_id = 
max(max_file_id, file_id) 126 | new_global_map[key] = local_map[(0, old_id)] 127 | for key, val in local_map.items(): 128 | file_id, spk_id = key 129 | if file_id == 1: 130 | new_global_map[(max_file_id + 1, spk_id)] = val 131 | self.global_mapping = new_global_map.copy() 132 | -------------------------------------------------------------------------------- /dover_lap/src/mapping/map_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Tuple 2 | 3 | from dover_lap.libs.utils import groupby 4 | from dover_lap.libs.turn import Turn 5 | 6 | 7 | def compute_spk_overlap(ref_spk_turns: List[Turn], sys_spk_turns: List[Turn]) -> float: 8 | """ 9 | Computes 'relative' overlap, i.e. Intersection Over Union 10 | """ 11 | tokens = [] 12 | all_turns = ref_spk_turns + sys_spk_turns 13 | for turn in all_turns: 14 | tokens.append(("BEG", turn.onset)) 15 | tokens.append(("END", turn.offset)) 16 | spk_count = 0 17 | ovl_duration = 0 18 | total_duration = 0 19 | for token in sorted(tokens, key=lambda x: x[1]): 20 | if token[0] == "BEG": 21 | spk_count += 1 22 | if spk_count == 2: 23 | ovl_begin = token[1] 24 | if spk_count == 1: 25 | speech_begin = token[1] 26 | else: 27 | spk_count -= 1 28 | if spk_count == 1: 29 | ovl_duration += token[1] - ovl_begin 30 | if spk_count == 0: 31 | total_duration += token[1] - speech_begin 32 | return ovl_duration / total_duration 33 | 34 | 35 | def get_speaker_keys(turns_list: List[List[Turn]]) -> Dict[Tuple[int, int], str]: 36 | """ 37 | Returns a dictionary which maps a file id (relative) and speaker id (relative) 38 | to absolute speaker id. 
39 | """ 40 | speakers_dict = {} 41 | for i, turns in enumerate(turns_list): 42 | turn_groups = { 43 | key: list(group) for key, group in groupby(turns, lambda x: x.speaker_id) 44 | } 45 | for j, key in enumerate(sorted(turn_groups.keys())): 46 | speakers_dict[(i, j)] = key 47 | return speakers_dict -------------------------------------------------------------------------------- /dover_lap/src/voting/__init__.py: -------------------------------------------------------------------------------- 1 | from .average import WeightedAverageVoting -------------------------------------------------------------------------------- /dover_lap/src/voting/average.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List 2 | 3 | import numpy as np 4 | from scipy.ndimage import gaussian_filter1d 5 | from scipy.stats import rankdata 6 | 7 | from dover_lap.libs.turn import Turn 8 | 9 | 10 | class WeightedAverageVoting: 11 | def __init__(self, gaussian_filter_std: float = 0.01) -> None: 12 | self.gaussian_filter_std = gaussian_filter_std 13 | 14 | def get_combined_turns( 15 | self, regions: np.ndarray, start_end: np.ndarray, file_id: str 16 | ) -> List[Turn]: 17 | """ 18 | Implements combination using the DOVER-Lap weighted average voting method. 19 | 20 | :param regions, matrix of shape (num_regions, num_speakers, num_hypothesis). 21 | The value in cell (t, k, n) represents the weight speaker `k` in region `t` 22 | contributed by hypothesis `n`. 
23 |         :param start_end, list of start and end times for each region
24 |         """
25 |         assert (
26 |             regions.shape[0] == start_end.shape[0]
27 |         ), "Regions and start_end must have the same number of rows"
28 | 
29 |         # Sum the weights from all hypotheses
30 |         regions = np.sum(regions, axis=2)
31 | 
32 |         # Apply Gaussian filter to the regions matrix along the T axis
33 |         regions = gaussian_filter1d(
34 |             regions, sigma=self.gaussian_filter_std, axis=0, mode="nearest"
35 |         )
36 | 
37 |         # Get the number of speakers in each region
38 |         num_spks = np.sum(regions, axis=1).round().astype(int)
39 | 
40 |         # Rank the weights in each region. We use the min method to break ties,
41 |         # i.e., speakers with equal weight all receive the smallest (best) rank
42 |         # among them. Note that we negate the regions matrix because rankdata()
43 |         # sorts in ascending order.
44 |         spk_ranks_matrix = rankdata(-1 * regions, axis=1, method="min")
45 | 
46 |         # Create turns list by combining the regions and start_end
47 |         combined_turns_list = []
48 |         for i in range(len(spk_ranks_matrix)):
49 |             start_time, end_time = start_end[i]
50 |             spk_ranks = spk_ranks_matrix[i]
51 |             for j, spk_rank in enumerate(spk_ranks):
52 |                 if spk_rank <= num_spks[i]:
53 |                     turn = Turn(
54 |                         onset=start_time,
55 |                         offset=end_time,
56 |                         speaker_id=j,
57 |                         file_id=file_id,
58 |                     )
59 |                     combined_turns_list.append(turn)
60 | 
61 |         return combined_turns_list
62 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | intervaltree>=3.0.1
2 | numpy>=1.20.1
3 | click>=7.1.2
4 | scipy>=1.5.4
5 | spy-der>=0.4.0
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | from setuptools import find_packages, setup
4 | 
5 | with open("requirements.txt") as f:
6
| requirements = f.read().splitlines()
7 | 
8 | long_description = open("README.md").read()
9 | 
10 | dev_requires = ["pre-commit", "black", "flake8"]
11 | 
12 | setup(
13 |     name="dover-lap",
14 |     version="1.3.2",
15 |     author="Desh Raj",
16 |     author_email="r.desh26@gmail.com",
17 |     description="Combine overlap-aware diarization output RTTMs",
18 |     keywords="diarization dover",
19 |     url="https://github.com/desh2608/dover-lap",
20 |     license="MIT License",
21 |     packages=find_packages(),
22 |     install_requires=requirements,
23 |     long_description=long_description,
24 |     long_description_content_type="text/markdown",
25 |     classifiers=[
26 |         "Development Status :: 3 - Alpha",
27 |         "Topic :: Utilities",
28 |         "License :: OSI Approved :: MIT License",
29 |     ],
30 |     entry_points={"console_scripts": ["dover-lap=dover_lap.dover_lap:main"]},
31 | )
32 | 
--------------------------------------------------------------------------------
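The rank-based hypothesis weighting in `dover_lap/src/doverlap.py` (`__get_ranks` followed by `__compute_weights`) can be illustrated in isolation: hypotheses are ranked by estimated quality (rank 1 = lowest average DER), each receives weight 1 / rank^d with d = `dover_weight` (default 0.1), and the weights are L1-normalized. A minimal sketch — the function name `dover_rank_weights` and the example DER values are illustrative, not part of the package:

```python
import numpy as np


def dover_rank_weights(avg_ders, d=0.1):
    """DOVER-style rank weighting: rank hypotheses by average DER
    (rank 1 = lowest DER), weight each as 1 / rank**d, then
    L1-normalize so the weights sum to 1."""
    avg_ders = np.asarray(avg_ders, dtype=float)
    ranks = np.empty(len(avg_ders), dtype=int)
    ranks[avg_ders.argsort()] = np.arange(len(avg_ders)) + 1
    weights = 1.0 / np.power(ranks, d)
    return weights / weights.sum()


# Hypothetical average DERs for three systems: the second one (0.20)
# gets rank 1 and therefore the largest weight.
weights = dover_rank_weights([0.30, 0.20, 0.25])
```

Because the exponent d is small, the weights stay close to uniform; the ranking only breaks ties gently rather than letting the best system dominate.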
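The pairwise cost used by the greedy mapping (`compute_spk_overlap` in `dover_lap/src/mapping/map_utils.py`) is an intersection-over-union of two speakers' turn sets, computed with a boundary sweep. The same idea in standalone form — `interval_iou` is an illustrative name, and unlike the library version this sketch guards against empty input; as in the original, any concurrency of two or more open intervals (including overlaps within one input list) counts toward the intersection:

```python
def interval_iou(ref, sys):
    """Intersection-over-union of two lists of (onset, offset) intervals,
    via a sweep that tracks how many intervals are currently open."""
    tokens = []
    for onset, offset in list(ref) + list(sys):
        tokens.append((onset, 1))    # interval opens
        tokens.append((offset, -1))  # interval closes
    # Sorting on (time, delta) closes intervals before opening new ones
    # at the same timestamp, so zero-length touches are not counted.
    tokens.sort()

    open_count, prev_t = 0, None
    intersection = union = 0.0
    for t, delta in tokens:
        if prev_t is not None and open_count > 0:
            union += t - prev_t
            if open_count >= 2:
                intersection += t - prev_t
        open_count += delta
        prev_t = t
    return intersection / union if union > 0 else 0.0
```

For example, `interval_iou([(0, 4)], [(2, 6)])` has intersection 2 (from 2 to 4) and union 6 (from 0 to 6).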
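In `dover_lap/src/voting/average.py`, each homogeneous region is labeled by summing the hypothesis weights per speaker, rounding the total to estimate the number of active speakers, and keeping the top-ranked speakers. The per-region decision, sketched standalone (`vote_region` is an illustrative name; the Gaussian smoothing across regions is omitted here, and ties follow the same `rankdata(..., method="min")` convention as the library):

```python
import numpy as np
from scipy.stats import rankdata


def vote_region(speaker_weights):
    """Pick active speakers for one region: the summed hypothesis
    weights total roughly 1 per active speaker, so rounding the sum
    estimates the speaker count; the top-ranked speakers by weight
    are then selected."""
    w = np.asarray(speaker_weights, dtype=float)
    num_spks = int(round(float(w.sum())))
    ranks = rankdata(-w, method="min")  # rank 1 = largest weight
    return [spk for spk in range(len(w)) if ranks[spk] <= num_spks]


# Weights summing to ~2 indicate a two-speaker (overlapping) region.
active = vote_region([0.9, 0.7, 0.3])
```

Note that with `method="min"`, exactly tied speakers share the better rank, so a tie at the cutoff can admit more speakers than the rounded count — the same behavior as the library code.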