├── requirements.txt ├── pytest.ini ├── vap_turn_taking ├── __init__.py ├── config │ └── example_data.py ├── backchannel.py ├── animation.py ├── events.py ├── utils.py ├── metrics.py ├── hold_shifts.py ├── plot_utils.py └── vap.py ├── setup.py ├── test ├── test_backchannel.py ├── test_hold_shift.py ├── test_vap.py └── test_events.py ├── LICENSE ├── .gitignore └── README.md /requirements.txt: -------------------------------------------------------------------------------- 1 | torchmetrics 2 | einops 3 | pytest 4 | matplotlib 5 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | markers = 3 | vap: test vap.py 4 | backchannel: test backchannel.py 5 | hold_shift: hold/shift extractor test 6 | events: event test 7 | vap: test for Voice Activity Projection module 8 | -------------------------------------------------------------------------------- /vap_turn_taking/__init__.py: -------------------------------------------------------------------------------- 1 | from .vap import ActivityEmb, VAP, VAPLabel 2 | from .hold_shifts import HoldShift 3 | from .backchannel import Backchannel 4 | from .events import TurnTakingEvents 5 | from .metrics import TurnTakingMetrics 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import setup 4 | 5 | setup( 6 | name="vap_turn_taking", 7 | version="0.0.0", 8 | description="VAP (Voice Activity Projection)", 9 | author="erikekst", 10 | author_email="erikekst@kth.se", 11 | url="https://github.com/ErikEkstedt/vap_turn_taking", 12 | packages=["vap_turn_taking"], 13 | ) 14 | -------------------------------------------------------------------------------- /test/test_backchannel.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from vap_turn_taking import Backchannel 3 | from vap_turn_taking.config.example_data import event_conf_frames, example 4 | 5 | 6 | @pytest.mark.backchannel() 7 | def test_backchannel(): 8 | 9 | va = example["va"] 10 | bc_label = example["backchannel"] 11 | 12 | bc_kwargs = event_conf_frames["bc"] 13 | bcer = Backchannel(**bc_kwargs) 14 | tt_bc = bcer(va, max_frame=None) 15 | 16 | ndiff = (bc_label != tt_bc["backchannel"]).sum().item() 17 | assert ndiff == 0, f"Backchannel diff {ndiff} != 0" 18 | -------------------------------------------------------------------------------- /test/test_hold_shift.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from vap_turn_taking import HoldShift 3 | from vap_turn_taking.config.example_data import example, event_conf_frames 4 | 5 | 6 | @pytest.mark.hold_shift 7 | def test_hold_shifts(): 8 | 9 | hs_kwargs = event_conf_frames["hs"] 10 | HS = HoldShift(**hs_kwargs) 11 | tt = HS(example["va"]) 12 | 13 | hdiff = (tt["hold"] != example["hold"]).sum() 14 | sdiff = (tt["shift"] != example["shift"]).sum() 15 | 16 | assert hdiff == 0, f"Backchannel hold diff {hdiff} != 0" 17 | assert sdiff == 0, f"Backchannel shift diff {sdiff} != 0" 18 | -------------------------------------------------------------------------------- /test/test_vap.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from vap_turn_taking import VAP 3 | from 
vap_turn_taking.config.example_data import example 4 | 5 | 6 | @pytest.mark.vap 7 | def test_vap_discrete(): 8 | 9 | vapper = VAP(type="discrete") 10 | va = example["va"] 11 | y = vapper.extract_label(va) 12 | assert y.ndim == 2, f"Shape error: {y.shape} != (B, N)" 13 | 14 | 15 | @pytest.mark.vap 16 | def test_vap_independent(): 17 | vapper = VAP(type="independent") 18 | y = vapper.extract_label(example["va"]) 19 | assert y.ndim == 4, "Shape error: y.ndim != 4. i.e. not (b, n, 2, 4)" 20 | assert y.shape[-2:] == (2, 4), f"Shape error: {y.shape[-2:]} != (..., c, n_bins)" 21 | 22 | 23 | @pytest.mark.vap 24 | def test_vap_comparative(): 25 | vapper = VAP(type="comparative") 26 | va = example["va"] 27 | y = vapper.extract_label(va) 28 | assert y.ndim == 2, f"Shape error: {y.shape} != (B, N)" 29 | -------------------------------------------------------------------------------- /test/test_events.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from vap_turn_taking import TurnTakingEvents 4 | from vap_turn_taking.config.example_data import example, event_conf 5 | 6 | 7 | @pytest.mark.events 8 | def test_events(): 9 | 10 | eventer = TurnTakingEvents( 11 | hs_kwargs=event_conf["hs"], 12 | bc_kwargs=event_conf["bc"], 13 | metric_kwargs=event_conf["metric"], 14 | frame_hz=100, 15 | ) 16 | 17 | va = example["va"] 18 | events = eventer(va) 19 | 20 | # Shift/Hold 21 | sdiff = (events["shift"] != example["shift"]).sum() 22 | assert sdiff == 0, f"SHIFT non-zero diff: {sdiff}" 23 | 24 | hdiff = (events["hold"] != example["hold"]).sum() 25 | assert hdiff == 0, f"HOLD non-zero diff: {hdiff}" 26 | 27 | long_diff = (events["long"] != example["long"]).sum() 28 | assert long_diff == 0, f"LONG non-zero diff: {long_diff}" 29 | 30 | short_diff = (events["short"] != example["short"]).sum() 31 | assert short_diff == 0, f"SHORT non-zero diff: {short_diff}" 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 ErikEks 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | pyrightconfig.json 131 | 132 | # custom 133 | assets/ 134 | -------------------------------------------------------------------------------- /vap_turn_taking/config/example_data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from vap_turn_taking.utils import time_to_frames 4 | 5 | # TODO:Decide on config format/class/yaml/json 6 | 7 | 8 | # Configs for Events 9 | metric_kwargs = dict( 10 | pad=0, # int, pad on silence (shift/hold) onset used for evaluating\ 11 | dur=0.2, # int, duration off silence (shift/hold) used for evaluating\ 12 | pre_label_dur=0.4, # int, frames prior to Shift-silence for prediction on-active shift 13 | onset_dur=0.2, 14 | min_context=3, 15 | ) 16 | hs_kwargs = dict( 17 | post_onset_shift=1, 18 | pre_offset_shift=1, 19 | post_onset_hold=1, 20 | pre_offset_hold=1, 21 | non_shift_horizon=2, 22 | metric_pad=metric_kwargs["pad"], 23 | metric_dur=metric_kwargs["dur"], 24 | metric_pre_label_dur=metric_kwargs["pre_label_dur"], 25 | metric_onset_dur=metric_kwargs["onset_dur"], 26 | ) 27 | bc_kwargs = dict( 28 | max_duration_frames=1, 29 | pre_silence_frames=1, 30 | post_silence_frames=1, 31 | min_duration_frames=metric_kwargs["onset_dur"], 32 | metric_dur_frames=metric_kwargs["onset_dur"], 33 | metric_pre_label_dur=metric_kwargs["pre_label_dur"], 34 | ) 35 | event_conf = {"hs": hs_kwargs, "bc": bc_kwargs, "metric": metric_kwargs} 36 | ###################################################################################### 37 | frame_hz = 100 38 | event_conf_frames = {} 39 | for k, v in event_conf.items(): 40 | event_conf_frames[k] = {} 41 | for kk, vv in v.items(): 42 | if kk != "non_shift_majority_ratio": 43 | event_conf_frames[k][kk] = time_to_frames(vv, frame_hz) 44 | 45 | ###################################################################################### 46 | a_bc = (585, 660) 47 | as_onset = (840, 950) 48 | a_post_hold = (1000, 1100) 49 | start = 1110 50 | b_bc = (start, start + 80) 51 | bs_onset = (350, 460) 52 | A_segments = [(0, 300), a_bc, as_onset, a_post_hold, (1200, 1360)] 53 | B_segments = [bs_onset, (480, 590), (670, 800), b_bc, (1350, 1590)] 54 | max_frame = max(A_segments[-1][-1], B_segments[-1][-1]) 55 | va = torch.zeros((max_frame, 2), dtype=torch.float) 56 | for start, end in A_segments: 57 | va[start:end, 0] = 1.0 58 | for start, end in B_segments: 59 | va[start:end, 1] = 1.0 60 | # Labels 61 | onset = event_conf_frames["metric"]["onset_dur"] 62 | dur = event_conf_frames["metric"]["dur"] 63 | # BC 64 | bc = torch.zeros_like(va) 65 | bc[a_bc[0] : a_bc[0] + onset, 0] = 1.0 66 | bc[b_bc[0] : b_bc[0] + onset, 1] = 1.0 67 | # S/H 68 | s = torch.zeros_like(va) 69 | s[300 : 300 + dur, 1] = 1.0 70 | s[800 : 800 + dur, 0] = 1.0 71 | # HOLD 72 | h = torch.zeros_like(va) 73 | h[as_onset[-1] : as_onset[-1] + dur, 0] = 1.0 74 | h[bs_onset[-1] : bs_onset[-1] + dur, 1] = 1.0 75 | # Long 76 | long = torch.zeros_like(va) 77 | 
long[350 : 350 + onset, 1] = 1 78 | long[840 : 840 + onset, 0] = 1 79 | # Short 80 | short = torch.zeros_like(va) 81 | short[a_bc[0] : a_bc[0] + onset, 0] = 1 82 | short[b_bc[0] : b_bc[0] + onset, 1] = 1 83 | # unsqueeze 84 | example = { 85 | "va": va.unsqueeze(0), 86 | "hold": h.unsqueeze(0), 87 | "shift": s.unsqueeze(0), 88 | "backchannel": bc.unsqueeze(0), 89 | "short": short.unsqueeze(0), 90 | "long": long.unsqueeze(0), 91 | } 92 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VAP: Voice Activity Projection 2 | 3 | 4 | WARNING: This is not actively maintained! 5 | 6 | Check out [VoiceActivityProjection](https://github.com/ErikEkstedt/VoiceActivityProjection) for the full model and 'vapper' modules. 7 | The code relevant for this codebase can be found in the following files: 8 | - [vap/objective.py](https://github.com/ErikEkstedt/VoiceActivityProjection/blob/main/vap/objective.py) 9 | - [vap/events.py](https://github.com/ErikEkstedt/VoiceActivityProjection/blob/main/vap/events.py) 10 | 11 | ------------------------------------------------- 12 | 13 | # VAP: Voice Activity Projection 14 | 15 | Voice Activity Projection module used in the paper [Voice Activity Projection: Self-supervised Learning of Turn-taking Events](). 16 | 17 | * VAP-head 18 | - An NN 'layer' which extracts VAP-labels (discrete, independent, comparative), maps projection windows to states, and defines zero-shot probabilities. 19 | * Events 20 | - Automatically extracts turn-taking events given Voice Activity (e.g. tensor: `(B, N_FRAMES, 2)`) for two speakers 21 | * Metrics 22 | - [Torchmetrics](https://torchmetrics.readthedocs.io/en/latest/) 23 | 24 | 25 | ## Installation 26 | 27 | Install `vap_turn_taking` 28 | 29 | * Preferably inside an environment, e.g. [miniconda](https://docs.conda.io/en/latest/miniconda.html) 30 | * Including a working installation of [pytorch](https://pytorch.org/) 31 | * [Optional] (for videos) Install FFMPEG: `conda install -c conda-forge ffmpeg` 32 | * Install dependencies: `pip install -r requirements.txt` 33 | * Install package: `pip install -e .` 34 | 35 | 36 | ## VAP 37 | See section 2 of the [paper](). 38 | 39 | The Voice Activity Projection module extracts model VA-labels ('discrete', 'independent', 40 | 'comparative') and, given voice activity and model logit outputs, 41 | extracts turn-taking ("zero-shot") probabilities. 42 | 43 | ```python 44 | from vap_turn_taking.config.example_data import example 45 | from vap_turn_taking import VAP 46 | 47 | 48 | vapper = VAP(type="discrete") 49 | 50 | # example of voice activity for 2 speakers 51 | va = example['va'] # Voice Activity (Batch, N_Frames, 2) 52 | 53 | 54 | # Extract labels: Voice Activity Projection windows 55 | # Discrete: (B, N_frames), class indices 56 | # Independent: (B, N_frames, 2, N_bins), binary vap_bins 57 | # Comparative: (B, N_frames), float scalar 58 | y = vapper.extract_label(va) 59 | 60 | # Associated logits (discrete/independent/comparative) 61 | logits = model(INPUTS) # same shape as the labels 62 | 63 | 64 | # Get "zero-shot" probabilities 65 | turn_taking_probs = vapper(logits, va) # keys: "p", "p_bc" 66 | # turn_taking_probs['p'], (B, N_frames, 2) -> probability of next speaker 67 | # turn_taking_probs['p_bc'], (B, N_frames, 2) -> probability of backchannel prediction 68 | ``` 69 | 70 | 71 | ## Events 72 | 73 | See section 3 of the [paper]().
74 | 75 | The module which extracts events from a Voice Activity representation, used to 76 | calculate scores over particular frames of interest. 77 | 78 | ```python 79 | from vap_turn_taking.config.example_data import example, event_conf 80 | from vap_turn_taking import TurnTakingEvents 81 | 82 | 83 | # example of voice activity for 2 speakers 84 | va = example['va'] # Voice Activity (Batch, N_Frames, 2) 85 | 86 | 87 | # Class to extract turn-taking events 88 | eventer = TurnTakingEvents( 89 | hs_kwargs=event_conf["hs"], 90 | bc_kwargs=event_conf["bc"], 91 | metric_kwargs=event_conf["metric"], 92 | frame_hz=100, 93 | ) 94 | 95 | # extract events from binary voice activity features 96 | events = eventer(va, max_frame=None) 97 | 98 | # all events are binary representations of size (B, N_frames, 2) 99 | # where 1 indicates an event-relevant frame. 100 | # events.keys(): [ 101 | # 'shift', 102 | # 'hold', 103 | # 'short', 104 | # 'long', 105 | # 'predict_shift_pos', 106 | # 'predict_shift_neg', 107 | # 'predict_bc_pos', 108 | # 'predict_bc_neg' 109 | # ] 110 | ``` 111 | 112 | Where the event configuration `event_conf` can be 113 | 114 | ```python 115 | # Configs for Events 116 | metric_kwargs = dict( 117 | pad=0, # seconds, pad on silence (shift/hold) onset used for evaluation 118 | dur=0.2, # seconds, duration of silence (shift/hold) used for evaluation 119 | pre_label_dur=0.4, # seconds, duration prior to Shift-silence used for the shift-prediction label 120 | onset_dur=0.2, 121 | min_context=3, 122 | ) 123 | hs_kwargs = dict( 124 | post_onset_shift=1, 125 | pre_offset_shift=1, 126 | post_onset_hold=1, 127 | pre_offset_hold=1, 128 | non_shift_horizon=2, 129 | metric_pad=metric_kwargs["pad"], 130 | metric_dur=metric_kwargs["dur"], 131 | metric_pre_label_dur=metric_kwargs["pre_label_dur"], 132 | metric_onset_dur=metric_kwargs["onset_dur"], 133 | ) 134 | bc_kwargs = dict( 135 | max_duration_frames=1, 136 | pre_silence_frames=1, 137 | post_silence_frames=1, 138 | min_duration_frames=metric_kwargs["onset_dur"], 139 | metric_dur_frames=metric_kwargs["onset_dur"], 140 | metric_pre_label_dur=metric_kwargs["pre_label_dur"], 141 | ) 142 | event_conf = {"hs": hs_kwargs, "bc": bc_kwargs, "metric": metric_kwargs} 143 | ``` 144 | 145 | 146 | ## Metrics 147 | 148 | See section 3 of the [paper](). 149 | 150 | Calculates metrics during training/evaluation given the `turn_taking_probs` 151 | from the `VAP`+model-output and the events from `TurnTakingEvents`. Built using [torchmetrics](https://torchmetrics.readthedocs.io/en/latest/).
152 | 153 | ```python 154 | from vap_turn_taking import TurnTakingMetrics 155 | from vap_turn_taking.config.example_data import example, event_conf 156 | 157 | 158 | va = example['va'] # Voice Activity (Batch, N_Frames, 2) 159 | 160 | 161 | metric = TurnTakingMetrics( 162 | hs_kwargs=event_conf["hs"], 163 | bc_kwargs=event_conf["bc"], 164 | metric_kwargs=event_conf["metric"], 165 | bc_pred_pr_curve=True, 166 | shift_pred_pr_curve=True, 167 | long_short_pr_curve=True, 168 | frame_hz=100, 169 | ) 170 | 171 | # Forward pass through a model, extract events, extract turn-taking probabilites 172 | logits = model(INPUTS) 173 | events = eventer(va, max_frame=None) 174 | turn_taking_probs = vapper(logits, va) # keys: "p", "p_bc" 175 | 176 | # Update metrics 177 | metric.update( 178 | p=turn_taking_probs["p"], 179 | bc_pred_probs=turn_taking_probs.get("bc_prediction", None), 180 | events=events, 181 | ) 182 | 183 | # Compute: finalize/aggregates the scores (usually used after epoch is finished) 184 | result = metric.compute() 185 | 186 | # Resets the metrics (usually used before starting a new epoch) 187 | result = metric.reset() 188 | ``` 189 | -------------------------------------------------------------------------------- /vap_turn_taking/backchannel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from vap_turn_taking.utils import find_island_idx_len 3 | from vap_turn_taking.hold_shifts import get_dialog_states, get_last_speaker 4 | 5 | 6 | def find_isolated_within(vad, prefix_frames, max_duration_frames, suffix_frames): 7 | """ 8 | ... <= prefix_frames (silence) | <= max_duration_frames (active) | <= suffix_frames (silence) ... 9 | """ 10 | 11 | isolated = torch.zeros_like(vad) 12 | for b, vad_tmp in enumerate(vad): 13 | for speaker in [0, 1]: 14 | starts, durs, vals = find_island_idx_len(vad_tmp[..., speaker]) 15 | for step in range(1, len(starts) - 1): 16 | # Activity condition: current step is active 17 | if vals[step] == 0: 18 | continue 19 | 20 | # Prefix condition: 21 | # check that current active step comes after a certain amount of inactivity 22 | if durs[step - 1] < prefix_frames: 23 | continue 24 | 25 | # Suffix condition 26 | # check that current active step comes after a certain amount of inactivity 27 | if durs[step + 1] < suffix_frames: 28 | continue 29 | 30 | current_dur = durs[step] 31 | if current_dur <= max_duration_frames: 32 | start = starts[step] 33 | end = start + current_dur 34 | isolated[b, start:end, speaker] = 1.0 35 | return isolated 36 | 37 | 38 | class Backchannel: 39 | def __init__( 40 | self, 41 | max_duration_frames, 42 | min_duration_frames, 43 | pre_silence_frames, 44 | post_silence_frames, 45 | metric_dur_frames, 46 | metric_pre_label_dur, 47 | ): 48 | 49 | assert ( 50 | metric_dur_frames <= max_duration_frames 51 | ), "`metric_dur_frames` must be less than `max_duration_frames`" 52 | self.max_duration_frames = max_duration_frames 53 | self.min_duration_frames = min_duration_frames 54 | self.pre_silence_frames = pre_silence_frames 55 | self.post_silence_frames = post_silence_frames 56 | self.metric_dur_frames = metric_dur_frames 57 | self.metric_pre_label_dur = metric_pre_label_dur 58 | 59 | def __repr__(self): 60 | s = "\nBackchannel" 61 | s += f"\n max_duration_frames: {self.max_duration_frames}" 62 | s += f"\n pre_silence_frames: {self.pre_silence_frames}" 63 | s += f"\n post_silence_frames: {self.post_silence_frames}" 64 | return s 65 | 66 | def backchannel(self, vad, last_speaker, max_frame=None, 
min_context=0): 67 | """ 68 | Finds backchannel based on VAD signal. Iterates over batches and speakers. 69 | 70 | Extracts segments of activity/non-activity to find backchannels. 71 | 72 | Backchannel Conditions 73 | 74 | * Backchannel activity must be shorter than `self.max_duration_frames` 75 | * Backchannel activity must follow activity from the other speaker 76 | * Silence prior to backchannel, in the "backchanneler" channel, must be greater than `self.pre_silence_frames` 77 | * Silence after backchannel, in the "backchanneler" channel, must be greater than `self.pre_silence_frames` 78 | """ 79 | 80 | bc_oh = torch.zeros_like(vad) 81 | pre_bc_oh = torch.zeros_like(vad) 82 | for b, vad_tmp in enumerate(vad): 83 | 84 | for speaker in [0, 1]: 85 | other_speaker = 0 if speaker == 1 else 1 86 | 87 | starts, durs, vals = find_island_idx_len(vad_tmp[..., speaker]) 88 | for step in range(1, len(starts) - 1): 89 | # Activity condition: current step is active 90 | if vals[step] == 0: 91 | continue 92 | 93 | # Activity duration condition: segment must be shorter than 94 | # a certain number of frames 95 | if durs[step] > self.max_duration_frames: 96 | continue 97 | 98 | if durs[step] < self.min_duration_frames: 99 | continue 100 | 101 | start = starts[step] 102 | 103 | # Shift-ish condition: 104 | # Was the other speaker active prior to this `backchannel` candidate? 105 | # If not than this is a short IPU in the middle of a turn 106 | pre_speaker_cond = last_speaker[b, start - 1] == other_speaker 107 | if not pre_speaker_cond: 108 | continue 109 | 110 | # Prefix condition: 111 | # check that current active step comes after a certain amount of inactivity 112 | if durs[step - 1] < self.pre_silence_frames: 113 | continue 114 | 115 | # Suffix condition 116 | # check that current active step comes after a certain amount of inactivity 117 | if durs[step + 1] < self.post_silence_frames: 118 | continue 119 | 120 | # Add segment as a backchanel 121 | end = starts[step] + durs[step] 122 | if self.metric_dur_frames > 0: 123 | end = starts[step] + self.metric_dur_frames 124 | 125 | # Max Frame condition: 126 | # can't have event outside of predictable window 127 | if max_frame is not None: 128 | if end >= max_frame: 129 | continue 130 | 131 | # Min Context condition: 132 | if starts[step] < min_context: 133 | continue 134 | 135 | bc_oh[b, starts[step] : end, speaker] = 1.0 136 | 137 | # Min Context condition: 138 | if (starts[step] - self.metric_pre_label_dur) < min_context: 139 | continue 140 | 141 | pre_bc_oh[ 142 | b, 143 | starts[step] - self.metric_pre_label_dur : starts[step], 144 | speaker, 145 | ] = 1.0 146 | return bc_oh, pre_bc_oh 147 | 148 | def __call__(self, vad, last_speaker=None, ds=None, max_frame=None, min_context=0): 149 | 150 | if ds is None: 151 | ds = get_dialog_states(vad) 152 | 153 | if last_speaker is None: 154 | last_speaker = get_last_speaker(vad, ds) 155 | 156 | bc_oh, pre_bc = self.backchannel( 157 | vad, last_speaker, max_frame=max_frame, min_context=min_context 158 | ) 159 | return {"backchannel": bc_oh, "pre_backchannel": pre_bc} 160 | 161 | 162 | if __name__ == "__main__": 163 | import matplotlib.pyplot as plt 164 | from vap_turn_taking.plot_utils import plot_vad_oh 165 | 166 | BS = Backhannel(**bs_dict) 167 | tt_bc = BS(va) 168 | 169 | (tt_bc["backchannel"] != bc).sum() 170 | 171 | n_rows = 4 172 | n_cols = 4 173 | fig, ax = plt.subplots(n_rows, n_cols, sharey=True, sharex=True, figsize=(16, 4)) 174 | b = 0 175 | for row in range(n_rows): 176 | for col in range(n_cols): 177 
| _ = plot_vad_oh(vad[b], ax=ax[row, col]) 178 | _ = plot_vad_oh( 179 | bc["backchannel"][b], 180 | ax=ax[row, col], 181 | colors=["purple", "purple"], 182 | alpha=0.8, 183 | ) 184 | b += 1 185 | if b == vad.shape[0]: 186 | break 187 | if b == vad.shape[0]: 188 | break 189 | plt.pause(0.1) 190 | -------------------------------------------------------------------------------- /vap_turn_taking/animation.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname 2 | from os import makedirs 3 | 4 | import torch 5 | import torchaudio 6 | import subprocess 7 | from tqdm import tqdm 8 | 9 | import matplotlib as mpl 10 | import matplotlib.pyplot as plt 11 | from matplotlib import animation 12 | from matplotlib.patches import Rectangle 13 | 14 | from vap_turn_taking.plot_utils import plot_vad_oh 15 | 16 | mpl.use("agg") 17 | 18 | """ 19 | This works if FFMPEG is installed correctly. The below line should work. 20 | 21 | conda install -c conda-forge ffmpeg 22 | 23 | """ 24 | 25 | 26 | class VAPanimation: 27 | def __init__( 28 | self, 29 | p, 30 | p_bc, 31 | p_class, 32 | vap_bins, 33 | x, 34 | va, 35 | events=None, 36 | window_duration=10, 37 | frame_hz=100, 38 | sample_rate=16000, 39 | fps=20, 40 | dpi=200, 41 | bin_frames=[20, 40, 60, 80], 42 | ) -> None: 43 | """""" 44 | # Model output 45 | self.p = p 46 | self.p_bc = p_bc 47 | self.p_class = p_class 48 | 49 | self.vap_bins = vap_bins 50 | self.weighted_oh = self._weighted_oh(p_class, vap_bins) 51 | self.best_p, self.best_idx = p_class.max(dim=-1) 52 | 53 | # Model input 54 | self.x = x # Waveform 55 | self.va = va # Voice Activity 56 | self.events = events # events 57 | 58 | # Parameters 59 | self.frame_hz = frame_hz 60 | self.sample_rate = sample_rate 61 | self.window_duration = window_duration 62 | self.window_frames = self.window_duration * self.frame_hz 63 | self.center = self.window_frames // 2 64 | self.bin_frames = bin_frames 65 | 66 | # Animation params 67 | self.dpi = dpi 68 | self.fps = fps 69 | self.frame_step = int(100.0 / self.fps) 70 | 71 | self.plot_kwargs = { 72 | "A": {"color": "b"}, 73 | "B": {"color": "orange"}, 74 | "va": {"alpha": 0.6}, 75 | "bc": {"color": "darkgreen", "alpha": 0.6}, 76 | "vap": {"ylim": [-0.5, 0.5], "width": 3}, 77 | "current": {"width": 5}, 78 | } 79 | 80 | self.fig, self.ax = plt.subplots(1, 1, figsize=(12, 4)) 81 | self.pred_ax = self.ax.twinx() 82 | self.vap_ax = self.ax.twinx() 83 | self.vap_patches = [] 84 | self.draw_vap_patches = True 85 | 86 | self.draw_static() 87 | self.started = False 88 | 89 | def _weighted_oh(self, p_class, vap_bins): 90 | weighted_oh = p_class.unsqueeze(-1).unsqueeze(-1) * vap_bins 91 | weighted_oh = weighted_oh.sum(dim=1) # sum all class onehot 92 | return weighted_oh 93 | 94 | def set_axis_lim(self): 95 | self.ax.set_xlim([0, self.window_frames]) 96 | # PROBS 97 | self.pred_ax.set_xlim([0, self.window_frames]) 98 | self.pred_ax.set_yticks([]) 99 | # VAP 100 | self.vap_ax.set_ylim([-1, 1]) 101 | self.vap_ax.set_yticks([]) 102 | 103 | def draw_static(self): 104 | self.current_line = self.pred_ax.vlines( 105 | self.center, 106 | ymin=-1, 107 | ymax=1, 108 | color="r", 109 | linewidth=self.plot_kwargs["current"]["width"], 110 | ) 111 | 112 | # VAP BOX 113 | s = (torch.tensor(self.bin_frames).cumsum(0) + self.center).tolist() 114 | ymin, ymax = self.plot_kwargs["vap"]["ylim"] 115 | w = s[-1] - self.center 116 | h = ymax - ymin 117 | # white background 118 | 119 | vap_background = Rectangle( 120 | xy=[self.center, ymin], 
width=w, height=h, color="w", alpha=1 121 | ) 122 | self.vap_ax.add_patch(vap_background) 123 | self.vap_ax.vlines( 124 | s, 125 | ymin=ymin, 126 | ymax=ymax, 127 | color="k", 128 | linewidth=self.plot_kwargs["vap"]["width"], 129 | ) 130 | self.vap_ax.plot( 131 | [self.center, s[-1]], 132 | [ymin, ymin], 133 | color="k", 134 | linewidth=self.plot_kwargs["vap"]["width"], 135 | ) 136 | self.vap_ax.plot( 137 | [self.center, s[-1]], 138 | [ymax, ymax], 139 | color="k", 140 | linewidth=self.plot_kwargs["vap"]["width"], 141 | ) 142 | 143 | def clear_ax(self): 144 | self.ax.cla() 145 | self.pa.remove() 146 | self.pb.remove() 147 | self.p_bc_a.remove() 148 | self.p_bc_b.remove() 149 | 150 | for i in range(len(self.vap_patches)): 151 | self.vap_patches[i].remove() 152 | 153 | def draw_step(self, step=0): 154 | if not self.started: 155 | self.started = True 156 | else: 157 | self.clear_ax() 158 | 159 | end = step + self.window_frames 160 | 161 | _ = plot_vad_oh( 162 | self.va[step:end], ax=self.ax, alpha=self.plot_kwargs["va"]["alpha"] 163 | ) 164 | 165 | # Draw probalitiy curves 166 | (self.pa,) = self.pred_ax.plot( 167 | self.p[step:end, 0], color=self.plot_kwargs["A"]["color"] 168 | ) 169 | (self.p_bc_a,) = self.pred_ax.plot( 170 | self.p_bc[step:end, 0], color=self.plot_kwargs["bc"]["color"] 171 | ) 172 | (self.pb,) = self.pred_ax.plot( 173 | self.p[step:, 1] - 1, color=self.plot_kwargs["B"]["color"] 174 | ) 175 | (self.p_bc_b,) = self.pred_ax.plot( 176 | self.p_bc[step:end, 1] - 1, color=self.plot_kwargs["bc"]["color"] 177 | ) 178 | 179 | # draw weighted oh projection 180 | 181 | h = self.plot_kwargs["vap"]["ylim"][-1] 182 | 183 | jj = 0 184 | for speaker, sp_color in zip( 185 | [0, 1], [self.plot_kwargs["A"]["color"], self.plot_kwargs["B"]["color"]] 186 | ): 187 | bf_cum = 0 188 | for bin, bf in enumerate(self.bin_frames): 189 | 190 | alpha = self.weighted_oh[step + self.center, speaker, bin].item() 191 | 192 | if self.draw_vap_patches: 193 | start = self.center + bf_cum 194 | vap_patch = Rectangle( 195 | xy=[start, -h * speaker], 196 | width=bf, 197 | height=h, 198 | color=sp_color, 199 | alpha=alpha, 200 | ) 201 | # self.vap_patches.append(vap_patch) 202 | self.vap_ax.add_patch(vap_patch) 203 | else: 204 | self.vap_ax.patches[jj + 1].set_alpha(alpha) 205 | # self.vap_patches.append(p) 206 | bf_cum += bf 207 | jj += 1 208 | 209 | self.draw_vap_patches = False 210 | 211 | def update(self, step): 212 | self.draw_step(step) 213 | self.set_axis_lim() 214 | return [] 215 | 216 | def ffmpeg_call(self, out_path, vid_path, wav_path): 217 | """ 218 | Overlay the static image on top of the video (saved with transparency) and 219 | adding the audio. 
220 | 221 | Arguments: 222 | vid_path: path to temporary dynamic video file 223 | wav_path: path to temporary audio file 224 | img_path: path to temporary static image 225 | out_path: path to save final video to 226 | """ 227 | cmd = [ 228 | "ffmpeg", 229 | "-loglevel", 230 | "error", 231 | "-y", 232 | "-i", 233 | vid_path, 234 | "-i", 235 | wav_path, 236 | "-vcodec", 237 | "libopenh264", 238 | out_path, 239 | ] 240 | p = subprocess.Popen(cmd, stdin=subprocess.PIPE) 241 | p.communicate() 242 | 243 | def save_video(self, path="test.mp4"): 244 | tmp_video_path = "/tmp/vap_video_ani.mp4" 245 | tmp_wav_path = "/tmp/vap_video_audio.wav" 246 | n_frames = self.p.shape[0] - self.center 247 | 248 | if len(dirname(path)) > 0: 249 | makedirs(dirname(path), exist_ok=True) 250 | 251 | sample_offset = int(self.sample_rate * self.center / self.frame_hz) 252 | 253 | # SAVE tmp waveform 254 | torchaudio.save( 255 | tmp_wav_path, 256 | self.x[sample_offset:].unsqueeze(0), 257 | sample_rate=self.sample_rate, 258 | ) 259 | 260 | # Save matplot video 261 | moviewriter = animation.FFMpegWriter( 262 | fps=self.fps # , codec="libopenh264", extra_args=["-threads", "16"] 263 | ) 264 | 265 | with moviewriter.saving(self.fig, tmp_video_path, dpi=self.dpi): 266 | for step in tqdm(range(0, n_frames, self.frame_step)): 267 | _ = self.update(step) 268 | moviewriter.grab_frame() 269 | 270 | self.ffmpeg_call(path, tmp_video_path, tmp_wav_path) 271 | 272 | 273 | if __name__ == "__main__": 274 | 275 | from argparse import ArgumentParser 276 | 277 | parser = ArgumentParser() 278 | parser.add_argument("--video_data", type=str) 279 | parser.add_argument("--filename", type=str, default="video.mp4") 280 | args = parser.parse_args() 281 | 282 | if not args.filename.endswith(".mp4"): 283 | args.filename += ".mp4" 284 | 285 | data = torch.load(args.video_data) 286 | # x = torch.rand(1, 16000) # Waveform 287 | # events = {"shift", "hold", "backchannel"} # Events 288 | # va = torch.rand(1, 100, 2) # Voice Activity 289 | # p = torch.rand(1, 100, 2) # Next speaker probs 290 | # p_bc = torch.rand(1, 100, 2) # Backchannel probs 291 | # p_class = torch.rand(1, 100) # Backchannel probs 292 | # vap_bins = torch.rand(256, 2, 4) # the binary representation of the bin window 293 | 294 | # Save video 295 | ani = VAPanimation( 296 | p=data["p"], 297 | p_bc=data["p_bc"], 298 | p_class=data["logits"].softmax(-1), 299 | vap_bins=data["vap_bins"], 300 | x=data["waveform"], 301 | va=data["va"], 302 | events=None, 303 | fps=20, 304 | ) 305 | ani.save_video(args.filename) 306 | -------------------------------------------------------------------------------- /vap_turn_taking/events.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy.random as np_random 3 | 4 | from vap_turn_taking.backchannel import Backchannel 5 | from vap_turn_taking.hold_shifts import HoldShift 6 | from vap_turn_taking.utils import ( 7 | time_to_frames, 8 | find_island_idx_len, 9 | get_dialog_states, 10 | get_last_speaker, 11 | ) 12 | 13 | 14 | class TurnTakingEvents: 15 | def __init__( 16 | self, 17 | hs_kwargs, 18 | bc_kwargs, 19 | metric_kwargs, 20 | frame_hz=100, 21 | ): 22 | self.frame_hz = frame_hz 23 | 24 | # Times to frames 25 | self.metric_kwargs = self.kwargs_to_frames(metric_kwargs, frame_hz) 26 | self.hs_kwargs = self.kwargs_to_frames(hs_kwargs, frame_hz) 27 | self.bc_kwargs = self.kwargs_to_frames(bc_kwargs, frame_hz) 28 | 29 | # values for metrics 30 | self.metric_min_context = 
self.metric_kwargs["min_context"] 31 | self.metric_pad = self.metric_kwargs["pad"] 32 | self.metric_dur = self.metric_kwargs["dur"] 33 | self.metric_onset_dur = self.metric_kwargs["onset_dur"] 34 | self.metric_pre_label_dur = self.metric_kwargs["pre_label_dur"] 35 | 36 | self.HS = HoldShift(**self.hs_kwargs) 37 | self.BS = Backchannel(**self.bc_kwargs) 38 | 39 | def kwargs_to_frames(self, kwargs, frame_hz): 40 | new_kwargs = {} 41 | for k, v in kwargs.items(): 42 | new_kwargs[k] = time_to_frames(v, frame_hz) 43 | return new_kwargs 44 | 45 | def __repr__(self): 46 | s = "TurnTakingEvents\n" 47 | s += str(self.HS) + "\n" 48 | s += str(self.BS) 49 | return s 50 | 51 | def count_occurances(self, x): 52 | n = 0 53 | for b in range(x.shape[0]): 54 | for sp in [0, 1]: 55 | _, _, v = find_island_idx_len(x[b, :, sp]) 56 | n += (v == 1).sum().item() 57 | return n 58 | 59 | def sample_negative_segments(self, x, n): 60 | """ 61 | Used to pick a subset of negative segments. 62 | That is on events where the negatives are constrained in certain 63 | single chunk segments. 64 | 65 | Used to sample negatives for LONG/SHORT prediction. 66 | 67 | all start onsets result in either longer or shorter utterances. 68 | Utterances defined as backchannels are considered short and utterances 69 | after pauses or at shifts are considered long. 70 | 71 | """ 72 | neg_candidates = [] 73 | for b in range(x.shape[0]): 74 | for sp in [0, 1]: 75 | starts, durs, v = find_island_idx_len(x[b, :, sp]) 76 | 77 | starts = starts[v == 1] 78 | durs = durs[v == 1] 79 | for s, d in zip(starts, durs): 80 | neg_candidates.append([b, s, s + d, sp]) 81 | 82 | sampled_negs = torch.arange(len(neg_candidates)) 83 | if len(neg_candidates) > n: 84 | sampled_negs = np_random.choice(sampled_negs, size=n, replace=False) 85 | 86 | negs = torch.zeros_like(x) 87 | for ni in sampled_negs: 88 | b, s, e, sp = neg_candidates[ni] 89 | negs[b, s:e, sp] = 1.0 90 | 91 | return negs.float() 92 | 93 | def sample_negatives(self, x, n, dur): 94 | """ 95 | 96 | Choose negative segments from x which contains long stretches of 97 | possible starts of the negative segments. 98 | 99 | Used to sample negatives from NON-SHIFTS which represent longer segments 100 | where every frame is a possible prediction point. 
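Arguments (a sketch inferred from the code below; names as in the signature): x: binary event tensor of shape (B, N_frames, 2); n: maximum number of negative segments to sample; dur: duration, in frames, of each sampled negative segment.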
101 | 102 | """ 103 | 104 | onset_pad_min = 3 105 | onset_pad_max = 10 106 | 107 | neg_candidates = [] 108 | for b in range(x.shape[0]): 109 | for sp in [0, 1]: 110 | starts, durs, v = find_island_idx_len(x[b, :, sp]) 111 | 112 | starts = starts[v == 1] 113 | durs = durs[v == 1] 114 | 115 | # Min context condition 116 | durs = durs[starts >= self.metric_min_context] 117 | starts = starts[starts >= self.metric_min_context] 118 | 119 | # Over minimum duration condition 120 | starts = starts[durs > dur] 121 | durs = durs[durs > dur] 122 | 123 | if len(starts) == 0: 124 | continue 125 | 126 | for s, d in zip(starts, durs): 127 | # end of valid frames minus duration of concurrent segment 128 | end = s + d - dur 129 | 130 | if end - s <= onset_pad_min: 131 | onset_pad = 0 132 | elif end - s <= onset_pad_max: 133 | onset_pad = onset_pad_min 134 | else: 135 | onset_pad = torch.randint(onset_pad_min, onset_pad_max, (1,))[ 136 | 0 137 | ].item() 138 | 139 | for neg_start in torch.arange(s + onset_pad, end, dur): 140 | neg_candidates.append([b, neg_start, sp]) 141 | 142 | sampled_negs = torch.arange(len(neg_candidates)) 143 | if len(neg_candidates) > n: 144 | sampled_negs = np_random.choice(sampled_negs, size=n, replace=False) 145 | 146 | negs = torch.zeros_like(x) 147 | for ni in sampled_negs: 148 | b, s, sp = neg_candidates[ni] 149 | negs[b, s : s + dur, sp] = 1.0 150 | return negs.float() 151 | 152 | def __call__(self, vad, max_frame=None): 153 | ds = get_dialog_states(vad) 154 | last_speaker = get_last_speaker(vad, ds) 155 | 156 | # TODO: 157 | # TODO: having all events as a list/dict with (b, start, end, speaker) may be very much faster? 158 | # TODO: 159 | 160 | # HOLDS/SHIFTS: 161 | # shift, pre_shift, long_shift_onset, 162 | # hold, pre_hold, long_hold_onset, 163 | # shift_overlap, pre_shift_overlap, non_shift 164 | tt = self.HS( 165 | vad=vad, ds=ds, max_frame=max_frame, min_context=self.metric_min_context 166 | ) 167 | 168 | # Backchannels: backchannel, pre_backchannel 169 | bcs = self.BS( 170 | vad=vad, 171 | last_speaker=last_speaker, 172 | max_frame=max_frame, 173 | min_context=self.metric_min_context, 174 | ) 175 | 176 | ####################################################### 177 | # LONG/SHORT 178 | ####################################################### 179 | # Investigate the model output at the start of an IPU 180 | # where SHORT segments are "backchannel" and LONG har onset on new TURN (SHIFTs) 181 | # or onset of HOLD ipus 182 | short = bcs["backchannel"] 183 | long = self.sample_negative_segments(tt["long_shift_onset"], 1000) 184 | 185 | ####################################################### 186 | # Predict shift 187 | ####################################################### 188 | # Pos: window, on activity, prior to EOT before a SHIFT 189 | # Neg: Sampled from NON-SHIFT, on activity. 
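# Note: the number of sampled negatives below matches the number of positive pre-shift events, keeping the two classes balanced.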
190 | n_predict_shift = self.count_occurances(tt["pre_shift"]) 191 | if n_predict_shift == 0: 192 | predict_shift_neg = torch.zeros_like(tt["pre_shift"]) 193 | else: 194 | # NON-SHIFT where someone is active 195 | activity = ds == 0 # only A 196 | activity = torch.logical_or(activity, ds == 3) # AND only B 197 | activity = activity[:, : tt["non_shift"].shape[1]].unsqueeze(-1) 198 | non_shift_on_activity = torch.logical_and(tt["non_shift"], activity) 199 | predict_shift_neg = self.sample_negatives( 200 | non_shift_on_activity, n_predict_shift, dur=self.metric_pre_label_dur 201 | ) 202 | 203 | ####################################################### 204 | # Predict backchannels 205 | ####################################################### 206 | # Pos: 0.5 second prior a backchannel 207 | # Neg: Sampled from NON-SHIFT, everywhere 208 | n_pre_bc = self.count_occurances(bcs["pre_backchannel"]) 209 | if n_pre_bc == 0: 210 | predict_bc_neg = torch.zeros_like(bcs["pre_backchannel"]) 211 | else: 212 | predict_bc_neg = self.sample_negatives( 213 | tt["non_shift"], n_pre_bc, dur=self.metric_pre_label_dur 214 | ) 215 | 216 | # return tt 217 | return { 218 | "shift": tt["shift"][:, :max_frame], 219 | "hold": tt["hold"][:, :max_frame], 220 | "short": short[:, :max_frame], 221 | "long": long[:, :max_frame], 222 | "predict_shift_pos": tt["pre_shift"][:, :max_frame], 223 | "predict_shift_neg": predict_shift_neg[:, :max_frame], 224 | "predict_bc_pos": bcs["pre_backchannel"][:, :max_frame], 225 | "predict_bc_neg": predict_bc_neg[:, :max_frame], 226 | } 227 | 228 | 229 | if __name__ == "__main__": 230 | import matplotlib.pyplot as plt 231 | from vap_turn_taking.config.example_data import example, event_conf 232 | from vap_turn_taking.plot_utils import plot_vad_oh, plot_event 233 | 234 | eventer = TurnTakingEvents( 235 | hs_kwargs=event_conf["hs"], 236 | bc_kwargs=event_conf["bc"], 237 | metric_kwargs=event_conf["metric"], 238 | frame_hz=100, 239 | ) 240 | va = example["va"] 241 | events = eventer(va, max_frame=None) 242 | print("long: ", (events["long"] != example["long"]).sum()) 243 | print("short: ", (events["short"] != example["short"]).sum()) 244 | print("shift: ", (events["shift"] != example["shift"]).sum()) 245 | print("hold: ", (events["hold"] != example["hold"]).sum()) 246 | # for k, v in events.items(): 247 | # if isinstance(v, torch.Tensor): 248 | # print(f"{k}: {tuple(v.shape)}") 249 | # else: 250 | # print(f"{k}: {v}") 251 | fig, ax = plot_vad_oh(va[0]) 252 | # _, ax = plot_event(events["shift"][0], ax=ax) 253 | _, ax = plot_event(events["hold"][0], color=["r", "r"], ax=ax) 254 | # _, ax = plot_event(events["short"][0], ax=ax) 255 | # _, ax = plot_event(events["long"][0], color=['r', 'r'], ax=ax) 256 | # _, ax = plot_event(example['short'][0], color=["g", "g"], ax=ax) 257 | # _, ax = plot_event(example['long'][0], color=["r", "r"], ax=ax) 258 | _, ax = plot_event(example["hold"][0], color=["b", "b"], ax=ax) 259 | # _, ax = plot_event(example['shift'][0], color=["g", "g"], ax=ax) 260 | # _, ax = plot_event(example['short'][0], color=["r", "r"], ax=ax) 261 | # _, ax = plot_event(example['long'][0], color=["r", "r"], ax=ax) 262 | # _, ax = plot_event(bc[0], color=["b", "b"], ax=ax) 263 | # _, ax = plot_event(tt["shift_overlap"][0], ax=ax) 264 | # _, ax = plot_event(events["short"][0], color=["b", "b"], alpha=0.2, ax=ax) 265 | # _, ax = plot_event(tt_bc["pre_backchannel"][0], alpha=0.2, ax=ax) 266 | # _, ax = plot_event(tt["hold"][0], color=["r", "r"], ax=ax) 267 | # _, ax = 
plot_event(tt['pre_shift'][0], color=['g', 'g'], alpha=0.2, ax=ax) 268 | # _, ax = plot_event(tt['pre_hold'][0], color=['r', 'r'], alpha=0.2, ax=ax) 269 | # _, ax = plot_event(tt['long_shift_onset'][0], color=['r', 'r'], alpha=0.2, ax=ax) 270 | # _, ax = plot_event(events["non_shift"][0], color=["r", "r"], alpha=0.2, ax=ax) 271 | plt.pause(0.1) 272 | -------------------------------------------------------------------------------- /vap_turn_taking/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from einops import rearrange 4 | 5 | 6 | def time_to_frames(time, frame_hz): 7 | if isinstance(time, list): 8 | time = torch.tensor(time) 9 | 10 | frame = time * frame_hz 11 | 12 | if isinstance(frame, torch.Tensor): 13 | frame = frame.long().tolist() 14 | else: 15 | frame = int(frame) 16 | 17 | return frame 18 | 19 | 20 | def frame2time(f, frame_time): 21 | return f * frame_time 22 | 23 | 24 | def time2frames(t, hop_time): 25 | return int(t / hop_time) 26 | 27 | 28 | def find_island_idx_len(x): 29 | """ 30 | Finds patches of the same value. 31 | 32 | starts_idx, duration, values = find_island_idx_len(x) 33 | 34 | e.g: 35 | ends = starts_idx + duration 36 | 37 | s_n = starts_idx[values==n] 38 | ends_n = s_n + duration[values==n] # find all patches with N value 39 | 40 | """ 41 | assert x.ndim == 1 42 | n = len(x) 43 | y = x[1:] != x[:-1] # pairwise unequal (string safe) 44 | i = torch.cat( 45 | (torch.where(y)[0], torch.tensor(n - 1, device=x.device).unsqueeze(0)) 46 | ).long() 47 | it = torch.cat((torch.tensor(-1, device=x.device).unsqueeze(0), i)) 48 | dur = it[1:] - it[:-1] 49 | idx = torch.cumsum( 50 | torch.cat((torch.tensor([0], device=x.device, dtype=torch.long), dur)), dim=0 51 | )[ 52 | :-1 53 | ] # positions 54 | return idx, dur, x[i] 55 | 56 | 57 | def find_label_match(source_idx, target_idx): 58 | match = torch.where(source_idx.unsqueeze(-1) == target_idx) 59 | midx = target_idx[match[-1]] # back to original idx 60 | frames = torch.zeros_like(source_idx) 61 | # Does not work on gpu: frames[match[:-1]] = 1.0 62 | frames[match[:-1]] = torch.ones_like(match[0]) 63 | return frames, midx 64 | 65 | 66 | def get_dialog_states(vad) -> torch.Tensor: 67 | """Vad to the full state of a 2 person vad dialog 68 | 0: only speaker 0 69 | 1: none 70 | 2: both 71 | 3: only speaker 1 72 | """ 73 | assert vad.ndim >= 1 74 | return (2 * vad[..., 1] - vad[..., 0]).long() + 1 75 | 76 | 77 | def last_speaker_single(s): 78 | start, _, val = find_island_idx_len(s) 79 | 80 | # exlude silences (does not effect last_speaker) 81 | # silences should be the value of the previous speaker 82 | sil_idx = torch.where(val == 1)[0] 83 | if len(sil_idx) > 0: 84 | if sil_idx[0] == 0: 85 | val[0] = 2 # 2 is both we don't know if its a shift or hold 86 | sil_idx = sil_idx[1:] 87 | val[sil_idx] = val[sil_idx - 1] 88 | # map speaker B state (=3) to 1 89 | val[val == 3] = 1 90 | # get repetition lengths 91 | repeat = start[1:] - start[:-1] 92 | # Find difference between original and repeated 93 | # and use diff to repeat the last speaker until the end of segment 94 | diff = len(s) - repeat.sum(0) 95 | repeat = torch.cat((repeat, diff.unsqueeze(0))) 96 | # repeat values to create last speaker over entire segment 97 | last_speaker = torch.repeat_interleave(val, repeat) 98 | return last_speaker 99 | 100 | 101 | def get_last_speaker(vad, ds): 102 | assert vad.ndim > 1, "must provide vad of size: (N, channels) or (B, N, channels)" 103 | 
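# ds encodes the per-frame dialog state from get_dialog_states: 0 = only speaker 0, 1 = silence, 2 = both (overlap), 3 = only speaker 1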
104 | # get last active speaker (for turn shift/hold) 105 | if vad.ndim < 3: 106 | last_speaker = last_speaker_single(ds) 107 | else: # (B, N, Channels) = (B, N, n_speakers) 108 | last_speaker = [] 109 | for b in range(vad.shape[0]): 110 | s = ds[b] 111 | last_speaker.append(last_speaker_single(s)) 112 | last_speaker = torch.stack(last_speaker) 113 | return last_speaker 114 | 115 | 116 | def vad_list_to_onehot(vad_list, hop_time, duration, channel_last=False): 117 | n_frames = time2frames(duration, hop_time) + 1 118 | 119 | if isinstance(vad_list[0][0], list): 120 | vad_tensor = torch.zeros((len(vad_list), n_frames)) 121 | for ch, ch_vad in enumerate(vad_list): 122 | for v in ch_vad: 123 | s = time2frames(v[0], hop_time) 124 | e = time2frames(v[1], hop_time) 125 | vad_tensor[ch, s:e] = 1.0 126 | else: 127 | vad_tensor = torch.zeros((1, n_frames)) 128 | for v in vad_list: 129 | s = time2frames(v[0], hop_time) 130 | e = time2frames(v[1], hop_time) 131 | vad_tensor[:, s:e] = 1.0 132 | 133 | if channel_last: 134 | vad_tensor = vad_tensor.permute(1, 0) 135 | 136 | return vad_tensor 137 | 138 | 139 | def vad_to_dialog_vad_states(vad) -> torch.Tensor: 140 | """Vad to the full state of a 2 person vad dialog 141 | 0: only speaker 0 142 | 1: none 143 | 2: both 144 | 3: only speaker 1 145 | """ 146 | assert vad.ndim >= 1 147 | return (2 * vad[..., 1] - vad[..., 0]).long() + 1 148 | 149 | 150 | def mutual_silences(vad, ds=None): 151 | if ds is None: 152 | ds = vad_to_dialog_vad_states(vad) 153 | return ds == 1 154 | 155 | 156 | def get_current_vad_onehot(vad, end, duration, speaker, frame_size): 157 | """frame_size in seconds""" 158 | start = end - duration 159 | n_frames = int(duration / frame_size) 160 | vad_oh = torch.zeros((2, n_frames)) 161 | 162 | for ch, ch_vad in enumerate(vad): 163 | for s, e in ch_vad: 164 | if start <= s <= end: 165 | rel_start = s - start 166 | v_start_frame = round(rel_start / frame_size) 167 | if start <= e <= end: # vad segment completely in chunk 168 | rel_end = e - start 169 | v_end_frame = round(rel_end / frame_size) 170 | vad_oh[ch, v_start_frame : v_end_frame + 1] = 1.0 171 | else: # only start in chunk -> fill until end 172 | vad_oh[ch, v_start_frame:] = 1.0 173 | elif start <= e <= end: # only end in chunk 174 | rel_end = e - start 175 | v_end_frame = round(rel_end / frame_size) 176 | vad_oh[ch, : v_end_frame + 1] = 1.0 177 | elif s > end: 178 | break 179 | 180 | # current speaker is always channel 0 181 | if speaker == 1: 182 | vad_oh = torch.stack((vad_oh[1], vad_oh[0])) 183 | 184 | return vad_oh 185 | 186 | 187 | def get_next_speaker(vad, ds): 188 | """Doing `get_next_speaker` in reverse""" 189 | # Reverse Vad 190 | vad_reversed = vad.flip(dims=(1,)) 191 | ds_reversed = ds.flip(dims=(1,)) 192 | # get "last speaker" 193 | next_speaker = get_last_speaker(vad_reversed, ds_reversed) 194 | # reverse back 195 | next_speaker = next_speaker.flip(dims=(1,)) 196 | return next_speaker 197 | 198 | 199 | def get_hold_shift_onehot(vad): 200 | ds = vad_to_dialog_vad_states(vad) 201 | prev_speaker = get_last_speaker(vad, ds) 202 | next_speaker = get_next_speaker(vad, ds) 203 | silence_ids = torch.where(vad.sum(-1) == 0) 204 | 205 | hold_one_hot = torch.zeros_like(prev_speaker) 206 | shift_one_hot = torch.zeros_like(prev_speaker) 207 | 208 | hold = prev_speaker[silence_ids] == next_speaker[silence_ids] 209 | hold_one_hot[silence_ids] = hold.long() 210 | shift_one_hot[silence_ids] = torch.logical_not(hold).long() 211 | return hold_one_hot, shift_one_hot 212 | 213 | 214 | # vad 
context history 215 | def get_vad_condensed_history(vad, t, speaker, bin_end_times=[60, 30, 15, 5, 0]): 216 | """ 217 | get the vad-condensed-history over the history of the dialog. 218 | 219 | the amount of active seconds are calculated for each speaker in the segments defined by `bin_end_times` 220 | (starting from 0). 221 | The idea is to represent the past further away from the current moment in time more granularly. 222 | 223 | for example: 224 | bin_end_times=[60, 30, 10, 5, 0] extracts activity for each speaker in the intervals: 225 | 226 | [-inf, t-60] 227 | [t-60, t-30] 228 | [t-30, t-10] 229 | [t-10, t-5] 230 | [t-50, t] 231 | 232 | The final representation is then the ratio of activity for the 233 | relevant `speaker` over the total activity, for each bin. if there 234 | is no activity, that is the segments span before the dialog started 235 | or (unusually) both are silent, then we set the ratio to 0.5, to 236 | indicate equal participation. 237 | 238 | Argument: 239 | - vad: list: [[(0, 3), (4, 6), ...], [...]] list of list of channel start and end time 240 | """ 241 | n_bins = len(bin_end_times) 242 | T = t - torch.tensor(bin_end_times) 243 | bin_times = [0] + T.tolist() 244 | 245 | bins = torch.zeros(2, n_bins) 246 | for ch, ch_vad in enumerate(vad): # iterate over each channel 247 | s = bin_times[0] 248 | for i, e in enumerate(bin_times[1:]): # iterate over bin segments 249 | if e < 0: # skip if before dialog start 250 | s = e # update 251 | continue 252 | for vs, ve in ch_vad: # iterate over channel VAD 253 | if vs >= s: # start inside bin time 254 | if vs < e and ve <= e: # both vad_start/end occurs in segment 255 | bins[ch][i] += ve - vs 256 | elif vs < e: # only start occurs in segment 257 | bins[ch][i] += e - vs 258 | elif ( 259 | vs > e 260 | ): # all starts occus after bin-end -> no need to process further 261 | break 262 | else: # vs is before segment 263 | if s <= ve <= e: # ending occurs in segment 264 | bins[ch][i] += ve - s 265 | # update bin start 266 | s = e 267 | # Avoid nan -> for loop 268 | # get the ratio of the relevant speaker 269 | # if there is no information (bins are before dialog start) we use an equal prior (=.5) 270 | ratios = torch.zeros(n_bins) 271 | for b in range(n_bins): 272 | binsum = bins[:, b].sum() 273 | if binsum > 0: 274 | ratios[b] = bins[speaker, b] / binsum 275 | else: 276 | ratios[b] = 0.5 # equal prior for segments with no one speaking 277 | return ratios 278 | 279 | 280 | @torch.no_grad() 281 | def get_activity_history(vad_frames, bin_end_frames, channel_last=True): 282 | """ 283 | 284 | Uses convolutions to sum the activity over each segment of interest. 285 | 286 | The kernel size is set to be the number of frames of any particular segment i.e. 287 | 288 | --------------------------------------------------- 289 | 290 | 291 | ``` 292 | ... h0 | h1 | h2 | h3 | h4 + 293 | distant past | | | | + 294 | -inf -> -t0 | | | | + 295 | 296 | ``` 297 | 298 | --------------------------------------------------- 299 | 300 | Arguments: 301 | vad_frames: torch.tensor: (Channels, N_Frames) or (N_Frames, Channels) 302 | bin_end_frames: list: boundaries for the activity history windows i.e. 
[6000, 3000, 1000, 500] 303 | channel_last: bool: if true we expect `vad_frames` to be (N_Frames, Channels) 304 | 305 | Returns: 306 | ratios: torch.tensor: (Channels, N_frames, bins) or (N_frames, bins, Channels) (dependent on `channel_last`) 307 | history_bins: torch.tesnor: same size as ratio but contains the number of active frames, over each segment, for both speakers. 308 | """ 309 | 310 | N = vad_frames.shape[0] 311 | if channel_last: 312 | vad_frames = rearrange(vad_frames, "n c -> c n") 313 | 314 | # container for the activity of the defined bins 315 | hist_bins = [] 316 | 317 | # Distance past activity history/ratio 318 | # The segment from negative infinity to the first bin_end_frames 319 | if vad_frames.shape[0] > bin_end_frames[0]: 320 | h0 = vad_frames[:, : -bin_end_frames[0]].cumsum(dim=-1) 321 | diff_pad = torch.ones(2, bin_end_frames[0]) * -1 322 | h0 = torch.cat((diff_pad, h0), dim=-1) 323 | else: 324 | # there is not enough duration to get any long time information 325 | # -> set to prior of equal speech 326 | # negative values for debugging to see where we provide prior 327 | # (not seen outside of this after r0/r1 further down) 328 | h0 = torch.ones(2, N) * -1 329 | hist_bins.append(h0) 330 | 331 | # Activity of segments defined by the the `bin_end_frames` 332 | 333 | # If 0 is not included in the window (i.e. the current frame) 334 | # we append it for consistency in loop below 335 | if bin_end_frames[-1] != 0: 336 | bin_end_frames = bin_end_frames + [0] 337 | 338 | # Loop over each segment window, construct conv1d (summation: all weights are 1.) 339 | # Omit end-frames which are not used for the current bin 340 | # concatenate activity sum with pad (= -1) at the start where the bin values are 341 | # not defined. 342 | for start, end in zip(bin_end_frames[:-1], bin_end_frames[1:]): 343 | ks = start - end 344 | if end > 0: 345 | vf = vad_frames[:, :-end] 346 | else: 347 | vf = vad_frames 348 | if vf.shape[1] > 0: 349 | filters = torch.ones((1, 1, ks), dtype=torch.float) 350 | vf = F.pad(vf, [ks - 1, 0]).unsqueeze(1) # add channel dim 351 | o = F.conv1d(vf, weight=filters).squeeze(1) # remove channel dim 352 | if end > 0: 353 | # print('diffpad: ', end) 354 | diff_pad = torch.ones(2, end) * -1 355 | o = torch.cat((diff_pad, o), dim=-1) 356 | else: 357 | # there is not enough duration to get any long time information 358 | # -> set to prior of equal speech 359 | # negative values for debugging to see where we provide prior 360 | # (not seen outside of this after r0/r1 further down) 361 | o = torch.ones(2, N) * -1 362 | hist_bins.append(o) 363 | 364 | # stack together -> (2, N, len(bin_end_frames) + 1) default: (2, N, 5) 365 | hist_bins = torch.stack(hist_bins, dim=-1) 366 | 367 | # find the ratios for each speaker 368 | r0 = hist_bins[0] / hist_bins.sum(dim=0) 369 | r1 = hist_bins[1] / hist_bins.sum(dim=0) 370 | 371 | # segments where both speakers are silent (i.e. [0, 0] activation) 372 | # are not defined (i.e. hist_bins / hist_bins.sum = 0 / 0 ). 
373 | # Where both speakers are silent they have equal amount of 374 | nan_inds = torch.where(r0.isnan()) 375 | r0[nan_inds] = 0.5 376 | r1[nan_inds] = 0.5 377 | 378 | # Consistent input/output with `channel_last` VAD 379 | if channel_last: 380 | ratio = torch.stack((r0, r1), dim=-1) 381 | else: 382 | ratio = torch.stack((r0, r1)) 383 | return ratio, hist_bins 384 | -------------------------------------------------------------------------------- /vap_turn_taking/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchmetrics import Metric, F1Score, PrecisionRecallCurve, StatScores 3 | 4 | from vap_turn_taking.events import TurnTakingEvents 5 | 6 | 7 | class F1_Hold_Shift(Metric): 8 | def __init__(self, *args, **kwargs): 9 | super().__init__(*args, **kwargs) 10 | self.stat_scores = StatScores(reduce="macro", multiclass=True, num_classes=2) 11 | 12 | def probs_shift_hold(self, p, shift, hold): 13 | probs, labels = [], [] 14 | 15 | for next_speaker in [0, 1]: 16 | ws = torch.where(shift[..., next_speaker]) 17 | if len(ws[0]) > 0: 18 | tmp_probs = p[ws][..., next_speaker] 19 | tmp_lab = torch.ones_like(tmp_probs, dtype=torch.long) 20 | probs.append(tmp_probs) 21 | labels.append(tmp_lab) 22 | 23 | # Hold label -> 0 24 | # Hold prob -> 1 - p # opposite guess 25 | wh = torch.where(hold[..., next_speaker]) 26 | if len(wh[0]) > 0: 27 | # complement in order to be combined with shifts 28 | tmp_probs = 1 - p[wh][..., next_speaker] 29 | tmp_lab = torch.zeros_like(tmp_probs, dtype=torch.long) 30 | probs.append(tmp_probs) 31 | labels.append(tmp_lab) 32 | 33 | if len(probs) > 0: 34 | probs = torch.cat(probs) 35 | labels = torch.cat(labels) 36 | else: 37 | probs = None 38 | labels = None 39 | return probs, labels 40 | 41 | def get_score(self, tp, fp, tn, fn, EPS=1e-9): 42 | precision = tp / (tp + fp + EPS) 43 | recall = tp / (tp + fn + EPS) 44 | f1 = tp / (tp + 0.5 * (fp + fn) + EPS) 45 | return f1, precision, recall 46 | 47 | def reset(self): 48 | self.stat_scores.reset() 49 | 50 | def compute(self): 51 | hold, shift = self.stat_scores.compute() 52 | 53 | # HOLD 54 | h_tp, h_fp, h_tn, h_fn, h_sup = hold 55 | h_f1, h_precision, h_recall = self.get_score(h_tp, h_fp, h_tn, h_fn) 56 | 57 | # SHIFT 58 | s_tp, s_fp, s_tn, s_fn, s_sup = shift 59 | s_f1, s_precision, s_recall = self.get_score(s_tp, s_fp, s_tn, s_fn) 60 | 61 | # Weighted F1 62 | f1h = h_f1 * h_sup 63 | f1s = s_f1 * s_sup 64 | tot = h_sup + s_sup 65 | f1_weighted = (f1h + f1s) / tot 66 | return { 67 | "f1_weighted": f1_weighted, 68 | "hold": { 69 | "f1": h_f1, 70 | "precision": h_precision, 71 | "recall": h_recall, 72 | "support": h_sup, 73 | }, 74 | "shift": { 75 | "f1": s_f1, 76 | "precision": s_precision, 77 | "recall": s_recall, 78 | "support": s_sup, 79 | }, 80 | } 81 | 82 | def update(self, p, hold, shift): 83 | probs, labels = self.probs_shift_hold(p, shift=shift, hold=hold) 84 | if probs is not None: 85 | self.stat_scores.update(probs, labels) 86 | 87 | 88 | class TurnTakingMetrics(Metric): 89 | """ 90 | Used with discrete model, VAProjection. 
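For reference, the scores produced by `F1_Hold_Shift.get_score` above reduce to plain arithmetic on the stat-score counts; a quick check with illustrative (made-up) counts:

```python
tp, fp, fn = 8, 2, 4
precision = tp / (tp + fp)              # 0.8
recall = tp / (tp + fn)                 # ~0.667
f1 = tp / (tp + 0.5 * (fp + fn))        # ~0.727, same as 2*P*R / (P + R)
print(round(precision, 3), round(recall, 3), round(f1, 3))  # 0.8 0.667 0.727
```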
91 | """ 92 | 93 | def __init__( 94 | self, 95 | hs_kwargs, 96 | bc_kwargs, 97 | metric_kwargs, 98 | threshold_pred_shift=0.5, 99 | threshold_short_long=0.5, 100 | threshold_bc_pred=0.5, 101 | bc_pred_pr_curve=False, 102 | shift_pred_pr_curve=False, 103 | long_short_pr_curve=False, 104 | frame_hz=100, 105 | dist_sync_on_step=False, 106 | ): 107 | # call `self.add_state`for every internal state that is needed for the metrics computations 108 | # dist_reduce_fx indicates the function that should be used to reduce 109 | # state from multiple processes 110 | super().__init__(dist_sync_on_step=dist_sync_on_step) 111 | 112 | # Metrics 113 | # self.f1: class to provide f1-weighted as well as other stats tp,fp,support, etc... 114 | self.hs = F1_Hold_Shift() 115 | self.predict_shift = F1Score( 116 | threshold=threshold_pred_shift, 117 | num_classes=2, 118 | multiclass=True, 119 | average="weighted", 120 | ) 121 | self.short_long = F1Score( 122 | threshold=threshold_short_long, 123 | num_classes=2, 124 | multiclass=True, 125 | average="weighted", 126 | ) 127 | self.predict_backchannel = F1Score( 128 | threshold=threshold_bc_pred, 129 | num_classes=2, 130 | multiclass=True, 131 | average="weighted", 132 | ) 133 | 134 | self.pr_curve_bc_pred = bc_pred_pr_curve 135 | if self.pr_curve_bc_pred: 136 | self.bc_pred_pr = PrecisionRecallCurve(pos_label=1) 137 | 138 | self.pr_curve_shift_pred = shift_pred_pr_curve 139 | if self.pr_curve_shift_pred: 140 | self.shift_pred_pr = PrecisionRecallCurve(pos_label=1) 141 | 142 | self.pr_curve_long_short = long_short_pr_curve 143 | if self.pr_curve_long_short: 144 | self.long_short_pr = PrecisionRecallCurve(pos_label=1) 145 | 146 | # Extract the frames of interest for the given metrics 147 | self.eventer = TurnTakingEvents( 148 | hs_kwargs=hs_kwargs, 149 | bc_kwargs=bc_kwargs, 150 | metric_kwargs=metric_kwargs, 151 | frame_hz=frame_hz, 152 | ) 153 | 154 | @torch.no_grad() 155 | def extract_events(self, va, max_frame=None): 156 | return self.eventer(va, max_frame=max_frame) 157 | 158 | def __repr__(self): 159 | s = "TurnTakingMetrics" 160 | s += self.eventer.__repr__() 161 | return s 162 | 163 | def update_short_long(self, p, short, long): 164 | """ 165 | The given speaker in short/long is the one who initiated an onset. 166 | 167 | Use the backchannel (prediction) prob to recognize short utterance. 168 | 169 | event -> label 170 | short -> 1 171 | long -> 0 172 | """ 173 | 174 | probs, labels = [], [] 175 | 176 | # At the onset of a SHORT utterance the probability associated 177 | # with that person being the next speaker should be low -> 0 178 | if short.sum() > 0: 179 | w = torch.where(short) 180 | p_short = p[w] 181 | probs.append(p_short) 182 | # labels.append(torch.zeros_like(p_short)) 183 | labels.append(torch.ones_like(p_short)) 184 | 185 | # At the onset of a LONG utterance the probability associated 186 | # with that person being the next speaker should be high -> 1 187 | if long.sum() > 0: 188 | w = torch.where(long) 189 | p_long = p[w] 190 | probs.append(p_long) 191 | # labels.append(torch.ones_like(p_long)) 192 | labels.append(torch.zeros_like(p_long)) 193 | 194 | if len(probs) > 0: 195 | probs = torch.cat(probs) 196 | labels = torch.cat(labels).long() 197 | self.short_long.update(probs, labels) 198 | 199 | if self.pr_curve_long_short: 200 | self.long_short_pr.update(probs, labels) 201 | 202 | def update_predict_shift(self, p, pos, neg): 203 | """ 204 | Predict upcomming speaker shift. The events pos/neg are given for the 205 | correct next speaker. 
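The gathering pattern used in `update_short_long` above (and in the update methods below) is easiest to see with concrete tensors: `torch.where` on a one-hot event mask yields (batch, frame, speaker) indices, which then index the probability tensor of the same shape. A minimal sketch; the shapes and values are fabricated for illustration:

```python
import torch

p = torch.rand(1, 6, 2)              # next-speaker probabilities (B, N, 2)
short = torch.zeros(1, 6, 2)
short[0, 2, 0] = 1                   # speaker 0 marked as a short onset at frame 2
short[0, 4, 1] = 1                   # speaker 1 marked at frame 4

w = torch.where(short)               # indices of non-zero entries: (batch, frame, speaker)
p_short = p[w]                       # shape (2,): probabilities at the marked frames
labels = torch.ones_like(p_short)    # short events get label 1 in update_short_long
print(p_short.shape, labels.long())  # torch.Size([2]) tensor([1, 1])
```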
206 | 207 | correct classifications 208 | * pos next_speaker -> 1 209 | * neg next_speaker -> 1 210 | 211 | so we flip the negatives to have label 0 and take 1-p as their associated predictions 212 | 213 | """ 214 | probs, labels = [], [] 215 | 216 | # At the onset of a SHORT utterance the probability associated 217 | # with that person being the next speaker should be low -> 0 218 | if pos.sum() > 0: 219 | w = torch.where(pos) 220 | p_pos = p[w] 221 | probs.append(p_pos) 222 | labels.append(torch.ones_like(p_pos)) 223 | 224 | # At the onset of a LONG utterance the probability associated 225 | # with that person being the next speaker should be high -> 1 226 | if neg.sum() > 0: 227 | w = torch.where(neg) 228 | p_neg = 1 - p[w] # reverse to make negatives have label 0 229 | probs.append(p_neg) 230 | labels.append(torch.zeros_like(p_neg)) 231 | 232 | if len(probs) > 0: 233 | probs = torch.cat(probs) 234 | labels = torch.cat(labels).long() 235 | self.predict_shift.update(probs, labels) 236 | 237 | if self.pr_curve_shift_pred: 238 | self.shift_pred_pr.update(probs, labels) 239 | 240 | def update_predict_backchannel(self, bc_pred_probs, pos, neg): 241 | """ 242 | bc_pred_probs contains the probabilities associated with the given speaker 243 | initiating a backchannel in the "foreseeble" future. 244 | 245 | At POSITIVE events the speaker resposible for the actual upcomming backchannel 246 | is the same as the speaker in the event. 247 | 248 | At NEGATIVE events the speaker that "could have been" responsible for the upcomming backchennel 249 | is THE OTHER speaker so the probabilities much be switched. 250 | The probabilties associated with predicting THE OTHER is goin to say a backchannel is wrong so we 251 | flip the probabilities such that they should be close to 0. 252 | 253 | """ 254 | probs, labels = [], [] 255 | 256 | if pos.sum() > 0: 257 | w = torch.where(pos) 258 | p_pos = bc_pred_probs[w] 259 | probs.append(p_pos) 260 | labels.append(torch.ones_like(p_pos)) 261 | 262 | if neg.sum() > 0: 263 | # where is negative samples? 
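To make the label convention of `update_predict_shift` concrete: positives keep their probability with label 1, while negatives are flipped to `1 - p` with label 0, so one binary F1 score covers both. A compact sketch with made-up tensors:

```python
import torch

p = torch.tensor([[[0.9, 0.1], [0.3, 0.7], [0.8, 0.2]]])  # (B=1, N=3, 2)
pos = torch.zeros_like(p); pos[0, 0, 0] = 1   # upcoming shift to speaker 0 at frame 0
neg = torch.zeros_like(p); neg[0, 2, 0] = 1   # negative (non-shift) sample for speaker 0 at frame 2

probs = torch.cat((p[torch.where(pos)], 1 - p[torch.where(neg)]))
labels = torch.cat((torch.ones(1), torch.zeros(1))).long()
print(probs)   # tensor([0.9000, 0.2000])
print(labels)  # tensor([1, 0])
```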
264 | wb, wn, w_speaker = torch.where(neg) 265 | w_backchanneler = torch.logical_not(w_speaker).long() 266 | 267 | # p_neg = 1 - bc_pred_probs[(wb, wn, w_backchanneler)] 268 | p_neg = bc_pred_probs[(wb, wn, w_backchanneler)] 269 | probs.append(p_neg) 270 | labels.append(torch.zeros_like(p_neg)) 271 | 272 | if len(probs) > 0: 273 | probs = torch.cat(probs) 274 | labels = torch.cat(labels).long() 275 | self.predict_backchannel(probs, labels) 276 | 277 | if self.pr_curve_bc_pred: 278 | self.bc_pred_pr.update(probs, labels) 279 | 280 | def reset(self): 281 | super().reset() 282 | self.hs.reset() 283 | self.predict_shift.reset() 284 | self.short_long.reset() 285 | self.predict_backchannel.reset() 286 | if self.pr_curve_bc_pred: 287 | self.bc_pred_pr.reset() 288 | 289 | if self.pr_curve_shift_pred: 290 | self.shift_pred_pr.reset() 291 | 292 | if self.pr_curve_long_short: 293 | self.long_short_pr.reset() 294 | 295 | def compute(self): 296 | f1_hs = self.hs.compute() 297 | f1_predict_shift = self.predict_shift.compute() 298 | f1_short_long = self.short_long.compute() 299 | 300 | ret = { 301 | "f1_hold_shift": f1_hs["f1_weighted"], 302 | "f1_predict_shift": f1_predict_shift, 303 | "f1_short_long": f1_short_long, 304 | } 305 | 306 | try: 307 | ret["f1_bc_prediction"] = self.predict_backchannel.compute() 308 | except: 309 | ret["f1_bc_prediction"] = -1 310 | 311 | if self.pr_curve_bc_pred: 312 | ret["pr_curve_bc_pred"] = self.bc_pred_pr.compute() 313 | 314 | if self.pr_curve_shift_pred: 315 | ret["pr_curve_shift_pred"] = self.shift_pred_pr.compute() 316 | 317 | if self.pr_curve_long_short: 318 | ret["pr_curve_long_short"] = self.long_short_pr.compute() 319 | 320 | ret["shift"] = f1_hs["shift"] 321 | ret["hold"] = f1_hs["hold"] 322 | return ret 323 | 324 | def update(self, p, bc_pred_probs=None, events=None, va=None): 325 | """ 326 | p: tensor, next_speaker probability. Must take into account current speaker such that it can be used for pre-shift/hold, backchannel-pred/ongoing 327 | pre_probs: tensor, on active next speaker probability for independent 328 | bc_pred_probs: tensor, Special probability associated with a backchannel prediction 329 | events: dict, containing information about the events in the sequences 330 | vad: tensor, VAD activity. Only used if events is not given. 
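Pulling the pieces together, a minimal usage sketch with the bundled example data; the probabilities are random stand-ins for real model output, and exact metric values depend on the torchmetrics version this repo pins:

```python
import torch
from vap_turn_taking.metrics import TurnTakingMetrics
from vap_turn_taking.config.example_data import example, event_conf

metric = TurnTakingMetrics(
    hs_kwargs=event_conf["hs"],
    bc_kwargs=event_conf["bc"],
    metric_kwargs=event_conf["metric"],
    frame_hz=100,
)

va = example["va"]                    # (B, N, 2) voice activity
events = metric.extract_events(va)    # dict of event masks
p = torch.rand_like(va.float())       # stand-in for next-speaker probabilities
metric.update(p=p, events=events)
result = metric.compute()
print(result["f1_hold_shift"], result["f1_predict_shift"])
```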
331 | 332 | 333 | events: [ 334 | 'shift', 335 | 'hold', 336 | 'short', 337 | 'long', 338 | 'predict_shift_pos', 339 | 'predict_shift_neg', 340 | 'predict_bc_pos', 341 | 'predict_bc_neg' 342 | ] 343 | """ 344 | 345 | # Find valid event-frames if event is not given 346 | if events is None: 347 | events = self.extract_events(va) 348 | 349 | # SHIFT/HOLD 350 | self.hs.update(p, hold=events["hold"], shift=events["shift"]) 351 | 352 | # Predict Shifts 353 | self.update_predict_shift( 354 | p, pos=events["predict_shift_pos"], neg=events["predict_shift_neg"] 355 | ) 356 | 357 | # PREDICT BACKCHANNELS & Short/Long 358 | if bc_pred_probs is not None: 359 | self.update_predict_backchannel( 360 | bc_pred_probs, 361 | pos=events["predict_bc_pos"], 362 | neg=events["predict_bc_neg"], 363 | ) 364 | 365 | # Long/Short 366 | self.update_short_long( 367 | bc_pred_probs, short=events["short"], long=events["long"] 368 | ) 369 | else: 370 | # Long/Short 371 | self.update_short_long(p, short=events["short"], long=events["long"]) 372 | 373 | 374 | def main_old(): 375 | from tqdm import tqdm 376 | import matplotlib.pyplot as plt 377 | from conv_ssl.evaluation.utils import load_dm, load_model 378 | from conv_ssl.utils import to_device 379 | 380 | # Load Data 381 | # The only required data is VAD (onehot encoding of voice activity) e.g. (B, N_FRAMES, 2) for two speakers 382 | dm = load_dm(batch_size=12) 383 | # diter = iter(dm.val_dataloader()) 384 | 385 | ################################################### 386 | # Load Model 387 | ################################################### 388 | # run_path = "how_so/VPModel/10krujrj" # independent 389 | run_path = "how_so/VPModel/sbzhz86n" # discrete 390 | # run_path = "how_so/VPModel/2608x2g0" # independent (same bin size) 391 | model = load_model(run_path=run_path, strict=False) 392 | model = model.eval() 393 | # model = model.to("cpu") 394 | # model = model.to("cpu") 395 | 396 | event_kwargs = dict( 397 | shift_onset_cond=1, 398 | shift_offset_cond=1, 399 | hold_onset_cond=1, 400 | hold_offset_cond=1, 401 | min_silence=0.15, 402 | non_shift_horizon=2.0, 403 | non_shift_majority_ratio=0.95, 404 | metric_pad=0.05, 405 | metric_dur=0.1, 406 | metric_onset_dur=0.3, 407 | metric_pre_label_dur=0.5, 408 | metric_min_context=1.0, 409 | bc_max_duration=1.0, 410 | bc_pre_silence=1.0, 411 | bc_post_silence=3.0, 412 | ) 413 | 414 | # # update vad_projection metrics 415 | # metric_kwargs = { 416 | # "event_pre": 0.5, # seconds used to estimate PRE-f1-SHIFT/HOLD 417 | # "event_min_context": 1.0, # min context duration before extracting metrics 418 | # "event_min_duration": 0.15, # the minimum required segment to extract SHIFT/HOLD (start_pad+target_duration) 419 | # "event_horizon": 1.0, # SHIFT/HOLD requires lookahead to determine mutual starts etc 420 | # "event_start_pad": 0.05, # Predict SHIFT/HOLD after this many seconds of silence after last speaker 421 | # "event_target_duration": 0.10, # duration of segment to extract each SHIFT/HOLD guess 422 | # "event_bc_target_duration": 0.25, # duration of activity, in a backchannel, to extract BC-ONGOING metrics 423 | # "event_bc_pre_silence": 1, # du 424 | # "event_bc_post_silence": 2, 425 | # "event_bc_max_active": 1.0, 426 | # "event_bc_prediction_window": 0.4, 427 | # "event_bc_neg_active": 1, 428 | # "event_bc_neg_prefix": 1, 429 | # "event_bc_ongoing_threshold": 0.5, 430 | # "event_bc_pred_threshold": 0.5, 431 | # } 432 | # # Updatemetric_kwargs metrics 433 | # for metric, val in metric_kwargs.items(): 434 | # 
model.conf["vad_projection"][metric] = val 435 | 436 | N = 10 437 | model.test_metric = model.init_metric( 438 | model.conf, model.frame_hz, bc_pred_pr_curve=False, **event_kwargs 439 | ) 440 | # tt_metrics = TurnTakingMetricsDiscrete(bin_times=model.conf['vad_projection']['bin_times']) 441 | for ii, batch in tqdm(enumerate(dm.val_dataloader()), total=N): 442 | batch = to_device(batch, model.device) 443 | ######################################################################## 444 | # Extract events/labels on full length (with horizon) VAD 445 | events = model.test_metric.extract_events(batch["vad"], max_frame=1000) 446 | ######################################################################## 447 | # Forward Pass through the model 448 | loss, out, batch = model.shared_step(batch) 449 | turn_taking_probs = model.get_next_speaker_probs( 450 | out["logits_vp"], vad=batch["vad"] 451 | ) 452 | ######################################################################## 453 | # Update metrics 454 | model.test_metric.update( 455 | p=turn_taking_probs["p"], 456 | pw=turn_taking_probs.get("pw", None), 457 | pre_probs=turn_taking_probs.get("pre_probs", None), 458 | bc_pred_probs=turn_taking_probs.get("bc_prediction", None), 459 | events=events, 460 | ) 461 | if ii == N: 462 | break 463 | result = model.test_metric.compute() 464 | print(result.keys()) 465 | 466 | for k, v in result.items(): 467 | print(f"{k}: {v}") 468 | 469 | 470 | if __name__ == "__main__": 471 | from vap_turn_taking.config.example_data import example, event_conf 472 | 473 | metric = TurnTakingMetrics( 474 | hs_kwargs=event_conf["hs"], 475 | bc_kwargs=event_conf["bc"], 476 | metric_kwargs=event_conf["metric"], 477 | bc_pred_pr_curve=True, 478 | shift_pred_pr_curve=True, 479 | long_short_pr_curve=True, 480 | frame_hz=100, 481 | ) 482 | 483 | # Update 484 | metric.update( 485 | p=turn_taking_probs["p"], 486 | bc_pred_probs=turn_taking_probs.get("bc_prediction", None), 487 | events=events, 488 | ) 489 | 490 | # Compute 491 | result = metric.compute() 492 | -------------------------------------------------------------------------------- /vap_turn_taking/hold_shifts.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from vap_turn_taking.utils import ( 3 | find_island_idx_len, 4 | get_dialog_states, 5 | get_last_speaker, 6 | ) 7 | 8 | 9 | class HoldShift: 10 | """ 11 | Hold/Shift extraction from VAD. Operates of Frames. 12 | 13 | Arguments: 14 | post_onset_shift: int, frames for shift onset cond 15 | pre_offset_shift: int, frames for shift offset cond 16 | post_onset_hold: int, frames for hold onset cond 17 | pre_offset_hold: int, frames for hold offset cond 18 | metric_pad: int, pad on silence (shift/hold) onset used for evaluating 19 | metric_dur: int, duration off silence (shift/hold) used for evaluating 20 | metric_pre_label_dur: int, frames prior to Shift-silence for prediction on-active shift 21 | non_shift_horizon: int, frames to define majority speaker window for Non-shift 22 | non_shift_majority_ratio: float, ratio of majority speaker 23 | 24 | Return: 25 | dict: {'shift', 'pre_shift', 'hold', 'pre_hold', 'non_shift'} 26 | 27 | Active: "---" 28 | Silent: "..." 29 | 30 | # SHIFTS 31 | 32 | onset: |<-- only A -->| 33 | A: ...........................|------------------- 34 | B: ----------------|.............................. 
35 | offset: |<-- only B -->| 36 | SHIFT: |XXXXXXXXXX| 37 | 38 | ----------------------------------------------------------- 39 | # HOLDS 40 | 41 | onset: |<-- only B -->| 42 | A: ............................................... 43 | B: ----------------|..........|------------------- 44 | offset: |<-- only B -->| 45 | HOLD: |XXXXXXXXXX| 46 | 47 | ----------------------------------------------------------- 48 | # NON-SHIFT 49 | 50 | Horizon: |<-- B majority -->| 51 | A: .....................................|--------- 52 | B: ----------------|......|------|................ 53 | non_shift: |XXXXXXXXXXXXXXXXXXX| 54 | 55 | A future horizon window must contain 'majority' activity from 56 | from the last speaker. In these moments we "know" a shift 57 | is a WRONG prediction. But closer to activity from the 'other' 58 | speaker, a turn-shift is appropriate. 59 | 60 | ----------------------------------------------------------- 61 | # metrics 62 | e.g. shift 63 | 64 | onset: |<-- only A -->| 65 | A: ...............................|--------------- 66 | B: ----------------|.............................. 67 | offset: |<-- only B -->| 68 | SHIFT: |XXXXXXXXXXXXXX| 69 | metric: |...|XXXXXX| 70 | metric: |pad| dur | 71 | 72 | ----------------------------------------------------------- 73 | 74 | 75 | Using 'dialog states' consisting of 4 different states 76 | 0. Only A is speaking 77 | 1. Silence 78 | 2. Overlap 79 | 3. Only B is speaking 80 | 81 | Shift GAP: 0 -> 1 -> 3 3 -> 1 -> 0 82 | Shift Overlap: 0 -> 2 -> 3 3 -> 2 -> 0 83 | HOLD: 0 -> 1 -> 0 3 -> 1 -> 3 84 | """ 85 | 86 | def __init__( 87 | self, 88 | post_onset_shift, 89 | pre_offset_shift, 90 | post_onset_hold, 91 | pre_offset_hold, 92 | non_shift_horizon, 93 | metric_pad, 94 | metric_dur, 95 | metric_pre_label_dur, 96 | metric_onset_dur, 97 | non_shift_majority_ratio=1, 98 | ): 99 | assert ( 100 | metric_onset_dur <= post_onset_shift 101 | ), "`metric_onset_dur` must be less or equal to `post_onset_shift`" 102 | 103 | self.post_onset_shift = post_onset_shift 104 | self.pre_offset_shift = pre_offset_shift 105 | self.post_onset_hold = post_onset_hold 106 | self.pre_offset_hold = pre_offset_hold 107 | 108 | self.metric_pad = metric_pad 109 | self.metric_dur = metric_dur 110 | self.min_silence = metric_pad + metric_dur 111 | self.metric_pre_label_dur = metric_pre_label_dur 112 | self.metric_onset_dur = metric_onset_dur 113 | 114 | self.non_shift_horizon = non_shift_horizon 115 | self.non_shift_majority_ratio = non_shift_majority_ratio 116 | 117 | # Templates 118 | self.shift_template = torch.tensor([[3, 1, 0], [0, 1, 3]]) # on Silence 119 | self.shift_overlap_template = torch.tensor([[3, 2, 0], [0, 2, 3]]) 120 | self.hold_template = torch.tensor([[0, 1, 0], [3, 1, 3]]) # on silence 121 | 122 | def __repr__(self): 123 | s = "Holds & Shifts" 124 | s += f"\n post_onset_shift: {self.post_onset_shift}" 125 | s += f"\n pre_offset_shift: {self.pre_offset_shift}" 126 | s += f"\n post_onset_hold: {self.post_onset_hold}" 127 | s += f"\n pre_offset_hold: {self.pre_offset_hold}" 128 | s += f"\n min_silence: {self.min_silence}" 129 | s += f"\n metric_pad: {self.metric_pad}" 130 | s += f"\n metric_dur: {self.metric_dur}" 131 | s += f"\n metric_pre_label_dur: {self.metric_pre_label_dur}" 132 | s += f"\n non_shift_horizon: {self.non_shift_horizon}" 133 | s += f"\n non_shift_majority_ratio: {self.non_shift_majority_ratio}" 134 | return s 135 | 136 | def fill_template(self, vad, ds, template): 137 | """ 138 | Used in practice to create VAD -> FILLED_VAD, where filled 
vad combines 139 | consecutive segments of activity from the same speaker as a single 140 | chunk. 141 | """ 142 | 143 | filled_vad = vad.clone() 144 | for b in range(ds.shape[0]): 145 | s, d, v = find_island_idx_len(ds[b]) 146 | if len(v) < 3: 147 | continue 148 | triads = v.unfold(0, size=3, step=1) 149 | next_speaker, steps = torch.where( 150 | (triads == template.unsqueeze(1)).sum(-1) == 3 151 | ) 152 | for ns, pre in zip(next_speaker, steps): 153 | cur = pre + 1 154 | # Fill the matching template 155 | filled_vad[b, s[cur] : s[cur] + d[cur], ns] = 1.0 156 | return filled_vad 157 | 158 | def match_template( 159 | self, 160 | vad, 161 | ds, 162 | template, 163 | pre_cond_frames, 164 | post_cond_frames, 165 | pre_match=False, 166 | onset_match=False, 167 | max_frame=None, 168 | min_context=0, 169 | ): 170 | """ 171 | Creates a onehot vector where the steps matching the template. 172 | Return: 173 | match_oh: torch.Tensor (B, N, 2), where the last bin corresponds to the next speaker 174 | """ 175 | 176 | hold_cond = template[0, 0] == template[0, -1] 177 | 178 | match_oh = torch.zeros((*ds.shape, 2), device=ds.device, dtype=torch.float) 179 | 180 | pre_match_oh = None 181 | if pre_match: 182 | pre_match_oh = torch.zeros( 183 | (*ds.shape, 2), device=ds.device, dtype=torch.float 184 | ) 185 | 186 | onset_match_oh = None 187 | if onset_match: 188 | onset_match_oh = torch.zeros( 189 | (*ds.shape, 2), device=ds.device, dtype=torch.float 190 | ) 191 | for b in range(ds.shape[0]): 192 | s, d, v = find_island_idx_len(ds[b]) 193 | 194 | if len(v) < 3: 195 | continue 196 | 197 | triads = v.unfold(0, size=3, step=1) 198 | next_speaker, steps = torch.where( 199 | (triads == template.unsqueeze(1)).sum(-1) == 3 200 | ) 201 | 202 | # ns: next_speaker, pre_step 203 | for ns, pre_step in zip(next_speaker, steps): 204 | # If template is of 'HOLD-type' then previous speaker is the 205 | # same as next speaker. Otherwise they are different. 206 | nos = 0 if ns == 1 else 1 # strictly the OTHER 'next speaker' 207 | ps = ns if hold_cond else nos # previous speaker 208 | 209 | cur = pre_step + 1 210 | post = pre_step + 2 211 | 212 | # Silence Condition: if the current step is silent (shift with gap and holds) 213 | # then we only care about silences over a certain duration. 214 | if v[cur] == 1 and d[cur] < self.min_silence: 215 | continue 216 | 217 | # Can this be useful? older way of only considering active segments where 218 | # pauses have not been filled... 219 | # Shifts are more sensible to overall activity around silence/overlap 220 | # and uses `filled_vad` as vad where consecutive 221 | # if vad is None: 222 | # if d[pre_step] >= pre_cond_frames and d[post] >= post_cond_frames: 223 | # match_oh[b, s[cur] : s[cur] + d[cur], ns] = 1.0 224 | # continue 225 | 226 | # pre_condition 227 | # using a filled version of the VAD signal we check wheather 228 | # only the 'previous speaker, ps' was active. 
This will then include 229 | # activity from that speaker deliminated by silence/pauses/holds 230 | pre_start = s[cur] - pre_cond_frames 231 | 232 | # print('pre_start: ', pre_start, s[cur]) 233 | pre_cond1 = vad[b, pre_start : s[cur], ps].sum() == pre_cond_frames 234 | not_ps = 0 if ps == 1 else 1 235 | pre_cond2 = vad[b, pre_start : s[cur], not_ps].sum() == 0 236 | pre_cond = torch.logical_and(pre_cond1, pre_cond2) 237 | 238 | if not pre_cond: 239 | # pre_cond = vad[b, pre_start : s[cur], ps].sum() 240 | # print("pre cond Failed: ", pre_cond, pre_cond_frames) 241 | # # print(vad[b, pre_start:s[cur]+d[cur]+10]) 242 | # input() 243 | continue 244 | 245 | # single speaker post 246 | post_start = s[post] 247 | post_end = post_start + post_cond_frames 248 | post_cond1 = vad[b, post_start:post_end, ns].sum() == post_cond_frames 249 | post_cond2 = vad[b, post_start:post_end, nos].sum() == 0 250 | post_cond = torch.logical_and(post_cond1, post_cond2) 251 | if not post_cond: 252 | # post_cond = vad[b, post_start:post_end, ns].sum() 253 | # print("post cond Failed: ", post_cond, post_cond_frames) 254 | # print(vad[b, pre_start:s[cur]+d[cur]+10]) 255 | # input() 256 | continue 257 | 258 | # start = s[cur] 259 | # end = s[cur] + d[cur] 260 | # if self.metric_pad > 0: 261 | # start += self.metric_pad 262 | # 263 | # if self.metric_dur > 0: 264 | # end = start + self.metric_dur 265 | 266 | # Max frame condition: 267 | # Can't have event outside of predictable window 268 | if max_frame is not None: 269 | if s[cur] >= max_frame: 270 | continue 271 | 272 | # Min context condition: 273 | if (s[cur] + self.metric_pad) < min_context: 274 | continue 275 | 276 | if pre_match: 277 | pre_match_oh[ 278 | b, s[cur] - self.metric_pre_label_dur : s[cur], ns 279 | ] = 1.0 280 | 281 | # end = s[cur] + self.metric_pad + d[cur] 282 | end = s[cur] + self.metric_pad + self.metric_dur 283 | # Max frame condition: 284 | # Can't have event outside of predictable window 285 | if max_frame is not None: 286 | if end >= max_frame: 287 | continue 288 | 289 | match_oh[b, s[cur] + self.metric_pad : end, ns] = 1.0 290 | 291 | if onset_match: 292 | end = s[post] + self.metric_onset_dur 293 | if max_frame is not None: 294 | if end >= max_frame: 295 | continue 296 | onset_match_oh[b, s[post] : end, ns] = 1.0 297 | 298 | return match_oh, pre_match_oh, onset_match_oh 299 | 300 | def non_shifts( 301 | self, 302 | vad, 303 | last_speaker, 304 | horizon, 305 | majority_ratio=1, 306 | max_frame=None, 307 | min_context=0, 308 | ): 309 | """ 310 | 311 | Non-shifts are all parts of the VAD signal where a future of `horizon` 312 | frames "overwhelmingly" belongs to a single speaker. The 313 | `majority_ratio` is a threshold over which the ratio of activity must belong to the last/current-speaker. 314 | 315 | Arguments: 316 | vad: torch.Tensor, (B, N, 2) 317 | horizon: int, length in frames of the horizon 318 | majority_ratio: float, ratio of which the majority speaker must occupy 319 | """ 320 | 321 | EPS = 1e-5 # used to avoid nans 322 | 323 | nb = vad.size(0) 324 | 325 | # future windows 326 | vv = vad[:, 1:].unfold(1, size=horizon, step=1).sum(dim=-1) 327 | vv = vv / (vv.sum(-1, keepdim=True) + EPS) 328 | 329 | diff = vad.shape[1] - vv.shape[1] 330 | 331 | if max_frame is not None: 332 | vv = vv[:, :max_frame] 333 | 334 | # Majority_ratio. Add eps to value to not miss majority_ratio==1. 
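The `unfold` call above is the core of the non-shift computation: it turns per-frame activity into a window over the next `horizon` frames at every position. A small check with made-up activity:

```python
import torch

horizon = 3
vad = torch.tensor([[[1., 0.], [1., 0.], [0., 0.], [0., 1.], [0., 1.], [0., 1.]]])  # (1, 6, 2)
vv = vad[:, 1:].unfold(1, size=horizon, step=1).sum(dim=-1)
print(vv.shape)  # torch.Size([1, 3, 2])
print(vv[0])     # per-speaker activity counts in each future window:
                 # tensor([[1., 1.], [0., 2.], [0., 3.]])
```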
335 | # because we divided 1 336 | maj_speaker_cond = majority_ratio <= (vv + EPS) 337 | 338 | # Last speaker 339 | a_last = last_speaker[:, : maj_speaker_cond.shape[1]] == 0 340 | b_last = last_speaker[:, : maj_speaker_cond.shape[1]] == 1 341 | a_non_shift = torch.logical_and(a_last, maj_speaker_cond[..., 0]) 342 | b_non_shift = torch.logical_and(b_last, maj_speaker_cond[..., 1]) 343 | ns = torch.stack((a_non_shift, b_non_shift), dim=-1).float() 344 | # fill to correct size (same as vad and all other events) 345 | z = torch.zeros((nb, diff, 2), device=ns.device) 346 | non_shift = torch.cat((ns, z), dim=1) 347 | 348 | # Min Context Condition 349 | # i.e. don't use negatives from before `min_context` 350 | if min_context > 0: 351 | non_shift[:, :min_context] = 0.0 352 | return non_shift 353 | 354 | def __call__( 355 | self, 356 | vad, 357 | ds=None, 358 | filled_vad=None, 359 | max_frame=None, 360 | min_context=0, 361 | return_list=False, 362 | ): 363 | 364 | if ds is None: 365 | ds = get_dialog_states(vad) 366 | 367 | if vad.device != self.hold_template.device: 368 | self.shift_template = self.shift_template.to(vad.device) 369 | self.shift_overlap_template = self.shift_overlap_template.to(vad.device) 370 | self.hold_template = self.hold_template.to(vad.device) 371 | 372 | if filled_vad is None: 373 | filled_vad = self.fill_template(vad, ds, self.hold_template) 374 | 375 | shift_oh, pre_shift_oh, long_shift_onset = self.match_template( 376 | filled_vad, 377 | ds, 378 | self.shift_template, 379 | pre_cond_frames=self.pre_offset_shift, 380 | post_cond_frames=self.post_onset_shift, 381 | pre_match=True, 382 | onset_match=True, 383 | max_frame=max_frame, 384 | min_context=min_context, 385 | ) 386 | shift_ov_oh, _, _ = self.match_template( 387 | filled_vad, 388 | ds, 389 | self.shift_overlap_template, 390 | pre_cond_frames=self.pre_offset_shift, 391 | post_cond_frames=self.post_onset_shift, 392 | pre_match=False, 393 | onset_match=False, 394 | max_frame=max_frame, 395 | min_context=min_context, 396 | ) 397 | hold_oh, pre_hold_oh, long_hold_onset = self.match_template( 398 | filled_vad, 399 | ds, 400 | self.hold_template, 401 | pre_cond_frames=self.pre_offset_hold, 402 | post_cond_frames=self.post_onset_hold, 403 | pre_match=True, 404 | onset_match=True, 405 | max_frame=max_frame, 406 | min_context=min_context, 407 | ) 408 | 409 | last_speaker = get_last_speaker(vad, ds) 410 | non_shift_oh = self.non_shifts( 411 | vad, 412 | last_speaker, 413 | horizon=self.non_shift_horizon, 414 | majority_ratio=self.non_shift_majority_ratio, 415 | max_frame=max_frame, 416 | min_context=min_context, 417 | ) 418 | 419 | return { 420 | "shift": shift_oh, 421 | "pre_shift": pre_shift_oh, 422 | "long_shift_onset": long_shift_onset, 423 | "hold": hold_oh, 424 | "pre_hold": pre_hold_oh, 425 | "long_hold_onset": long_hold_onset, 426 | "shift_overlap": shift_ov_oh, 427 | "non_shift": non_shift_oh, 428 | } 429 | 430 | 431 | if __name__ == "__main__": 432 | import matplotlib.pyplot as plt 433 | from vap_turn_taking.plot_utils import plot_vad_oh, plot_event 434 | from vap_turn_taking.config.example_data import event_conf_frames, example 435 | 436 | plt.close("all") 437 | 438 | hs_kwargs = event_conf_frames["hs"] 439 | HS = HoldShift(**hs_kwargs) 440 | tt = HS(example["va"], max_frame=None) 441 | for k, v in tt.items(): 442 | if isinstance(v, torch.Tensor): 443 | print(f"{k}: {tuple(v.shape)}") 444 | else: 445 | print(f"{k}: {v}") 446 | print("shift: ", (example["shift"] != tt["shift"]).sum()) 447 | print("hold: ", 
(example["hold"] != tt["hold"]).sum()) 448 | 449 | fig, ax = plot_vad_oh(va[0]) 450 | # # _, ax = plot_event(tt["shift"][0], ax=ax) 451 | # _, ax = plot_event(s[0], color=["g", "g"], ax=ax) 452 | # _, ax = plot_event(h[0], color=["r", "r"], ax=ax) 453 | # _, ax = plot_event(bc[0], color=["b", "b"], ax=ax) 454 | # _, ax = plot_event(tt["shift_overlap"][0], ax=ax) 455 | # _, ax = plot_event(tt_bc["backchannel"][0], color=["b", "b"], alpha=0.2, ax=ax) 456 | # _, ax = plot_event(tt_bc["pre_backchannel"][0], alpha=0.2, ax=ax) 457 | # _, ax = plot_event(tt["hold"][0], color=["r", "r"], ax=ax) 458 | # _, ax = plot_event(tt['pre_shift'][0], color=['g', 'g'], alpha=0.2, ax=ax) 459 | # _, ax = plot_event(tt['pre_hold'][0], color=['r', 'r'], alpha=0.2, ax=ax) 460 | # _, ax = plot_event(tt['long_shift_onset'][0], color=['r', 'r'], alpha=0.2, ax=ax) 461 | _, ax = plot_event(tt["non_shift"][0], color=["r", "r"], alpha=0.2, ax=ax) 462 | plt.pause(0.1) 463 | -------------------------------------------------------------------------------- /vap_turn_taking/plot_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import math 5 | from matplotlib.patches import Rectangle 6 | 7 | 8 | def plot_area(oh, ax, label=None, color="b", alpha=1, hatch=None): 9 | ax.fill_between( 10 | torch.arange(oh.shape[0]), 11 | y1=-1, 12 | y2=1, 13 | where=oh, 14 | color=color, 15 | alpha=alpha, 16 | label=label, 17 | hatch=hatch, 18 | ) 19 | 20 | 21 | def plot_projection_window( 22 | proj_win, 23 | bin_frames=None, 24 | ax=None, 25 | colors=["b", "orange"], 26 | yticks=["B", "A"], 27 | ylabel=None, 28 | alpha=1, 29 | label=(None, None), 30 | legend_loc="best", 31 | plot=False, 32 | ): 33 | """ 34 | proj_win: torch.Tensor: (2, N) 35 | """ 36 | fig = None 37 | if ax is None: 38 | fig, ax = plt.subplots(1, 1, figsize=(12, 4)) 39 | 40 | # Create X from bin_frames/oh-bins 41 | if bin_frames is None: 42 | x = torch.arange(proj_win.shape[1]) # 0, 1, 2, 3 43 | xmax = len(x) 44 | bin_lines = x[1:].tolist() 45 | # add last end point n+1 46 | # Using 'step' in plot and 'post' so we must add values to the right to get correct fill 47 | # x: 0, 1, 2, 3, 4 48 | x = torch.cat((x, x[-1:] + 1)) 49 | else: 50 | assert ( 51 | len(bin_frames) == proj_win.shape[1] 52 | ), "len(bin_frames) != proj_win.shape[0]" 53 | if isinstance(bin_frames, torch.Tensor): 54 | x = bin_frames 55 | else: 56 | x = torch.tensor(bin_frames) 57 | # Add first 0 start point 58 | # [0, 20, 40, 60, 80] 59 | x = torch.cat((torch.zeros((1,)), x)) 60 | x = x.cumsum(0) 61 | xmax = x[-1] 62 | bin_lines = x[:-1].long().tolist() 63 | 64 | # add last to match x (which includes 0) 65 | proj_win = torch.cat((proj_win, proj_win[:, -1:]), dim=-1) 66 | 67 | # Fill bins with color 68 | ax.fill_between( 69 | x, 70 | y1=0, 71 | y2=proj_win[0], 72 | step="post", 73 | alpha=alpha, 74 | color=colors[0], 75 | label=label[1], 76 | ) 77 | ax.fill_between( 78 | x, 79 | y1=0, 80 | y2=-proj_win[1], 81 | step="post", 82 | alpha=alpha, 83 | label=label[0], 84 | color=colors[1], 85 | ) 86 | # Add lines separating bins 87 | ax.hlines(y=0, xmin=0, xmax=xmax, color="k", linestyle="dashed") 88 | for step in bin_lines: 89 | ax.vlines(x=step, ymin=-1, ymax=1, color="k") 90 | if label[0] is not None: 91 | ax.legend(loc=legend_loc) 92 | 93 | # set X/Y-ticks/labels 94 | ax.set_xlim([0, xmax]) 95 | ax.set_xticks(x[:-1]) 96 | ax.set_xticklabels(x[:-1].tolist()) 97 | ax.set_ylim([-1.05, 1.05]) 98 | if 
yticks is None: 99 | ax.set_yticks([]) 100 | else: 101 | ax.set_yticks([-0.5, 0.5]) 102 | ax.set_yticklabels(yticks) 103 | if ylabel is not None: 104 | ax.set_ylabel(ylabel) 105 | 106 | if plot: 107 | plt.tight_layout() 108 | plt.pause(0.1) 109 | return fig, ax 110 | 111 | 112 | def plot_events( 113 | va, 114 | hold=None, 115 | shift=None, 116 | event=None, 117 | event_alpha=0.3, 118 | vad_alpha=0.6, 119 | ax=None, 120 | plot=True, 121 | figsize=(9, 6), 122 | ): 123 | va = va.cpu() 124 | 125 | fig = None 126 | if ax is None: 127 | fig, ax = plt.subplots(1, 1, sharex=True, figsize=figsize) 128 | 129 | _ = plot_vad_oh( 130 | # va, ax=ax.twinx(), alpha=vad_alpha, legend_loc="upper right", label=["B", "A"] 131 | va, 132 | ax, 133 | alpha=vad_alpha, 134 | legend_loc="upper right", 135 | label=["B", "A"], 136 | ) 137 | 138 | # Twin axis for events 139 | if hold is not None or shift is not None or event is not None: 140 | twinax = ax.twinx() 141 | 142 | if hold is not None: 143 | hold = hold.detach().cpu() 144 | plot_area( 145 | hold[:, 0], 146 | ax=twinax, 147 | label="hold -> a", 148 | color="red", 149 | alpha=event_alpha, 150 | hatch="*", 151 | ) 152 | plot_area( 153 | hold[:, 1], ax=twinax, label="hold -> b", color="red", alpha=event_alpha 154 | ) 155 | twinax.legend(loc="upper left") 156 | 157 | if shift is not None: 158 | shift = shift.detach().cpu() 159 | plot_area( 160 | shift[:, 0], 161 | ax=twinax, 162 | label="Shift -> A", 163 | color="green", 164 | alpha=event_alpha, 165 | hatch="*", 166 | ) 167 | plot_area( 168 | shift[:, 1], ax=twinax, label="Shift -> B", color="green", alpha=event_alpha 169 | ) 170 | twinax.legend(loc="upper left") 171 | 172 | if event is not None: 173 | event = event.detach().cpu() 174 | plot_area(event, ax=twinax, label="Event", color="purple", alpha=event_alpha) 175 | twinax.legend(loc="upper left") 176 | 177 | # Twin axis for events 178 | if hold is not None or shift is not None or event is not None: 179 | twinax.set_yticks([]) 180 | twinax.set_ylim([-1.05, 1.05]) 181 | 182 | ax.set_ylim([-1.05, 1.05]) 183 | ax.set_xlim([0, va.shape[0]]) 184 | 185 | if plot: 186 | plt.tight_layout() 187 | plt.pause(0.1) 188 | 189 | return fig, ax 190 | 191 | 192 | ######################################################################### 193 | def plot_backchannel_prediction(va, bc_pred, bc_color="r", linewidth=2, plot=False): 194 | n_frames = bc_pred.shape[1] 195 | if bc_pred.ndim == 3: 196 | B = bc_pred.shape[0] 197 | fig, ax = plt.subplots(B, 1, sharex=True, sharey=True) 198 | for b in range(B): 199 | plot_vad_oh(va[b, :n_frames], ax=ax[b]) 200 | ax[b].plot(bc_pred[b, :, 0], color=bc_color, linewidth=linewidth) 201 | ax[b].plot(-bc_pred[b, :, 1], color=bc_color, linewidth=linewidth) 202 | else: 203 | assert va.ndim == 2, "no batch dimension va must be (N, 2)" 204 | fig, ax = plt.subplots(1, 1, sharex=True, sharey=True) 205 | plot_vad_oh(va[:n_frames], ax=ax) 206 | ax.plot(bc_pred[:, 0], color=bc_color, linewidth=linewidth) 207 | ax.plot(-bc_pred[:, 1], color=bc_color, linewidth=linewidth) 208 | 209 | if plot: 210 | plt.pause(0.01) 211 | 212 | return fig, ax 213 | 214 | 215 | ######################################################################### 216 | # Plot multiple projection windows 217 | ######################################################################### 218 | def plot_all_projection_windows( 219 | proj_wins, bin_times=[0.2, 0.4, 0.6, 0.8], vad_hz=100, figsize=(12, 9), plot=True 220 | ): 221 | bin_frames = [b * vad_hz for b in bin_times] 222 | n = 
proj_wins.shape[0] 223 | n_cols = 4 224 | n_rows = math.ceil(n / n_cols) 225 | # figsize = (4 * n_cols, 3 * n_rows) 226 | print("n: ", n) 227 | print("(cols,rows): ", n_cols, n_rows) 228 | print("figsize: ", figsize) 229 | if n_rows == 1: 230 | fig, ax = plt.subplots(1, n_cols, sharex=True, sharey=True, figsize=figsize) 231 | for col in range(n_cols): 232 | if col >= n: 233 | break 234 | _ = plot_projection_window( 235 | proj_wins[col], bin_frames=bin_frames, ax=ax[col] 236 | ) 237 | ax[0].set_xticks([]) 238 | ax[0].set_yticks([]) 239 | else: 240 | i = 0 241 | fig, ax = plt.subplots( 242 | n_rows, n_cols, sharex=True, sharey=True, figsize=figsize 243 | ) 244 | for row in range(n_rows): 245 | for col in range(n_cols): 246 | _ = plot_projection_window( 247 | proj_wins[i], bin_frames=bin_frames, ax=ax[row, col] 248 | ) 249 | i += 1 250 | if i >= n: 251 | break 252 | if i >= n: 253 | break 254 | 255 | ax[0, 0].set_xticks([]) 256 | ax[0, 0].set_yticks([]) 257 | plt.tight_layout() 258 | if plot: 259 | plt.pause(0.1) 260 | 261 | return fig, ax 262 | 263 | 264 | ######################################################################### 265 | # Template 266 | ######################################################################### 267 | def plot_template( 268 | projection_type, 269 | prefix_type, 270 | bin_times=[0.2, 0.4, 0.6, 0.8], 271 | alpha_required=0.6, 272 | alpha_optional=0.2, 273 | alpha_prefix=0.6, 274 | lw_box=2, 275 | legend_ts=12, 276 | pad=0.02, 277 | plot=False, 278 | ): 279 | assert projection_type in [ 280 | "shift", 281 | "bc_prediction", 282 | ], "projection type must be in ['shift', 'bc_prediction']" 283 | assert prefix_type in [ 284 | "silence", 285 | "both", 286 | "active", 287 | "overlap", 288 | ], "prefix type must be in ['silence', 'active', 'overlap']" 289 | 290 | colors = ["b", "orange"] 291 | current = 1.5 292 | bins = [current] + (np.array(bin_times).cumsum(0) + current).tolist() 293 | 294 | fig, ax = plt.subplots(1, 1) 295 | handles = [] 296 | # Draw horizontal line (speaker separation) 297 | ax.hlines(y=0, xmin=0, xmax=bins[-1], linewidth=2, color="k") 298 | # current, = ax.plot([current, current], [-1, 1], linewidth=5, color="r", label='current time') 299 | # handles.append(current) 300 | 301 | ####################################################################### 302 | # Draw prefix boxes 303 | ####################################################################### 304 | if prefix_type == "silence": 305 | ax.add_patch( 306 | Rectangle( 307 | xy=(0, -1), 308 | width=1, 309 | height=1, 310 | facecolor=colors[1], 311 | alpha=alpha_prefix, 312 | edgecolor=colors[1], 313 | ) 314 | ) 315 | elif prefix_type == "both": 316 | ax.add_patch( 317 | Rectangle( 318 | xy=(0, -1), 319 | width=1, 320 | height=1, 321 | facecolor=colors[1], 322 | alpha=alpha_prefix, 323 | edgecolor=colors[1], 324 | ) 325 | ) 326 | ax.add_patch( 327 | Rectangle( 328 | xy=(1, -1), 329 | width=0.5, 330 | height=1, 331 | facecolor=colors[1], 332 | alpha=alpha_optional + 0.05, 333 | edgecolor=colors[1], 334 | ) 335 | ) 336 | else: 337 | ax.add_patch( 338 | Rectangle( 339 | xy=(0, -1), 340 | width=1.5, 341 | height=1, 342 | facecolor=colors[1], 343 | alpha=alpha_prefix, 344 | edgecolor=colors[1], 345 | ) 346 | ) 347 | # ax.add_patch( 348 | # Rectangle( 349 | # xy=(0, -1), 350 | # width=1, 351 | # height=1, 352 | # facecolor=colors[1], 353 | # alpha=alpha_prefix, 354 | # edgecolor=colors[1], 355 | # ) 356 | # ) 357 | # ax.add_patch( 358 | # Rectangle( 359 | # xy=(1, -1), 360 | # width=0.5, 361 | # height=1, 
362 | # facecolor=colors[1], 363 | # alpha=alpha_optional+0.05, 364 | # edgecolor=colors[1], 365 | # ) 366 | # ) 367 | if prefix_type == "overlap": 368 | ax.add_patch( 369 | Rectangle( 370 | xy=(1, 0), 371 | width=0.5, 372 | height=1, 373 | facecolor=colors[0], 374 | edgecolor=colors[0], 375 | ) 376 | ) 377 | 378 | ####################################################################### 379 | # Projection Window Template 380 | ####################################################################### 381 | if projection_type == "shift": 382 | # Optional 383 | optional_a = Rectangle( 384 | xy=(current, 0), 385 | width=bins[2] - bins[0], 386 | height=1, 387 | label="A optional", 388 | facecolor=colors[0], 389 | alpha=alpha_optional, 390 | edgecolor=colors[0], 391 | ) 392 | handles.append(optional_a) 393 | ax.add_patch(optional_a) 394 | 395 | # Required 396 | required_a = Rectangle( 397 | xy=(bins[2], 0), 398 | width=bins[-1] - bins[2], 399 | height=1, 400 | label="A required", 401 | facecolor=colors[0], 402 | alpha=alpha_required, 403 | hatch=".", 404 | edgecolor=colors[0], 405 | ) 406 | handles.append(required_a) 407 | ax.add_patch(required_a) 408 | if prefix_type != "silence": 409 | # Optional B 410 | optional_b = Rectangle( 411 | xy=(current, -1), 412 | width=bins[2] - bins[0], 413 | height=1, 414 | label="B optional", 415 | facecolor=colors[1], 416 | alpha=alpha_optional, 417 | edgecolor=colors[1], 418 | ) 419 | handles.append(optional_b) 420 | ax.add_patch(optional_b) 421 | elif projection_type == "bc_prediction": 422 | # A at least one 423 | a_one = Rectangle( 424 | xy=(current, 0), 425 | width=bins[3] - bins[0], 426 | height=1, 427 | label="A at least 1", 428 | facecolor=colors[0], 429 | alpha=0.2, 430 | edgecolor=colors[0], 431 | hatch="//", 432 | ) 433 | handles.append(a_one) 434 | ax.add_patch(a_one) 435 | # B optional 436 | optional_b = Rectangle( 437 | xy=(current, -1), 438 | width=bins[3] - bins[0], 439 | height=1, 440 | label="B optional", 441 | facecolor=colors[1], 442 | alpha=alpha_optional, 443 | edgecolor=colors[1], 444 | ) 445 | handles.append(optional_b) 446 | ax.add_patch(optional_b) 447 | 448 | # B required 449 | required_b = Rectangle( 450 | xy=(bins[3], -1), 451 | width=bins[-1] - bins[3], 452 | height=1, 453 | label="B required", 454 | facecolor=colors[1], 455 | alpha=alpha_required, 456 | hatch=".", 457 | edgecolor=colors[1], 458 | ) 459 | handles.append(required_b) 460 | ax.add_patch(required_b) 461 | 462 | # Draw lines for projection window template 463 | ax.vlines(bins[1:-1], ymin=-1, ymax=1, linewidth=lw_box, color="k") 464 | ax.hlines(y=[-1, 1], xmin=1.5, xmax=bins[-1], linewidth=lw_box, color="k") 465 | ax.vlines(bins[-1], ymin=-1, ymax=1, linewidth=lw_box, color="k") 466 | # Draw current line 467 | ax.vlines(current, ymin=-1, ymax=1, linewidth=5, color="r", label="current time") 468 | 469 | ax.set_yticks([]) 470 | ax.set_xticks([]) 471 | ax.set_xlim([0, bins[-1] + 0.05]) 472 | ax.legend(loc="upper left", handles=handles, fontsize=legend_ts) 473 | 474 | plt.subplots_adjust( 475 | left=pad, bottom=pad, right=1 - pad, top=1 - pad, wspace=None, hspace=None 476 | ) 477 | if plot: 478 | plt.pause(0.1) 479 | return fig, ax 480 | 481 | 482 | ######################################################################################## 483 | # New 484 | def plot_vad_oh( 485 | vad_oh, 486 | ax=None, 487 | colors=["b", "orange"], 488 | yticks=["B", "A"], 489 | ylabel=None, 490 | alpha=1, 491 | label=(None, None), 492 | legend_loc="best", 493 | plot=False, 494 | ): 495 | """ 496 | 
vad_oh: torch.Tensor: (N, 2) 497 | """ 498 | fig = None 499 | if ax is None: 500 | fig, ax = plt.subplots(1, 1, figsize=(12, 4)) 501 | 502 | x = torch.arange(vad_oh.shape[0]) + 0.5 # fill_between step = 'mid' 503 | ax.fill_between( 504 | x, 505 | y1=0, 506 | y2=vad_oh[:, 0], 507 | step="mid", 508 | alpha=alpha, 509 | color=colors[0], 510 | label=label[1], 511 | ) 512 | ax.fill_between( 513 | x, 514 | y1=0, 515 | y2=-vad_oh[:, 1], 516 | step="mid", 517 | alpha=alpha, 518 | label=label[0], 519 | color=colors[1], 520 | ) 521 | if label[0] is not None: 522 | ax.legend(loc=legend_loc) 523 | ax.hlines(y=0, xmin=0, xmax=len(x), color="k", linestyle="dashed") 524 | ax.set_xlim([0, vad_oh.shape[0]]) 525 | ax.set_ylim([-1.05, 1.05]) 526 | 527 | if yticks is None: 528 | ax.set_yticks([]) 529 | else: 530 | ax.set_yticks([-0.5, 0.5]) 531 | ax.set_yticklabels(yticks) 532 | if ylabel is not None: 533 | ax.set_ylabel(ylabel) 534 | 535 | plt.tight_layout() 536 | if plot: 537 | plt.pause(0.1) 538 | return fig, ax 539 | 540 | 541 | def plot_event(ev, label=[None, None], color=["g", "g"], alpha=0.5, ax=None): 542 | assert ev.ndim == 2, "Must provide event of (N, 2)" 543 | fig = None 544 | if ax is None: 545 | fig, ax = plt.subplots(1, 1) 546 | for speaker in range(ev.shape[-1]): 547 | plot_area( 548 | ev[:, speaker], 549 | ax=ax, 550 | label=label[speaker], 551 | color=color[speaker], 552 | alpha=alpha, 553 | ) 554 | return fig, ax 555 | 556 | 557 | if __name__ == "__main__": 558 | 559 | from vap_turn_taking.config.example_data import event_conf, example 560 | from vap_turn_taking import TurnTakingEvents 561 | 562 | eventer = TurnTakingEvents( 563 | hs_kwargs=event_conf["hs"], 564 | bc_kwargs=event_conf["bc"], 565 | metric_kwargs=event_conf["metric"], 566 | frame_hz=100, 567 | ) 568 | 569 | va = example["va"] 570 | events = eventer(va, max_frame=None) 571 | 572 | fig, ax = plot_vad_oh(va[0]) 573 | # _, ax = plot_event(events["shift"][0], ax=ax) 574 | _, ax = plot_event(events["hold"][0], color=["r", "r"], ax=ax) 575 | # _, ax = plot_event(events["short"][0], ax=ax) 576 | # _, ax = plot_event(events["long"][0], color=['r', 'r'], ax=ax) 577 | # _, ax = plot_event(example['short'][0], color=["g", "g"], ax=ax) 578 | # _, ax = plot_event(example['long'][0], color=["r", "r"], ax=ax) 579 | _, ax = plot_event(example["hold"][0], color=["b", "b"], ax=ax) 580 | # _, ax = plot_event(example['shift'][0], color=["g", "g"], ax=ax) 581 | # _, ax = plot_event(example['short'][0], color=["r", "r"], ax=ax) 582 | # _, ax = plot_event(example['long'][0], color=["r", "r"], ax=ax) 583 | # _, ax = plot_event(bc[0], color=["b", "b"], ax=ax) 584 | # _, ax = plot_event(tt["shift_overlap"][0], ax=ax) 585 | # _, ax = plot_event(events["short"][0], color=["b", "b"], alpha=0.2, ax=ax) 586 | # _, ax = plot_event(tt_bc["pre_backchannel"][0], alpha=0.2, ax=ax) 587 | # _, ax = plot_event(tt["hold"][0], color=["r", "r"], ax=ax) 588 | # _, ax = plot_event(tt['pre_shift'][0], color=['g', 'g'], alpha=0.2, ax=ax) 589 | # _, ax = plot_event(tt['pre_hold'][0], color=['r', 'r'], alpha=0.2, ax=ax) 590 | # _, ax = plot_event(tt['long_shift_onset'][0], color=['r', 'r'], alpha=0.2, ax=ax) 591 | # _, ax = plot_event(events["non_shift"][0], color=["r", "r"], alpha=0.2, ax=ax) 592 | plt.pause(0.1) 593 | -------------------------------------------------------------------------------- /vap_turn_taking/vap.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import 
rearrange 4 | from typing import List 5 | 6 | from vap_turn_taking.utils import vad_to_dialog_vad_states 7 | 8 | 9 | def probs_ind_backchannel(probs): 10 | """ 11 | 12 | Extract the probabilities associated with 13 | 14 | A: |__|--|--|__| 15 | B: |__|__|__|--| 16 | 17 | """ 18 | bc_pred = [] 19 | 20 | # Iterate over speakers 21 | for current_speaker, backchanneler in zip([1, 0], [0, 1]): 22 | # Between speaker diff 23 | # -------------------- 24 | # Is the last bin of the "backchanneler" less probable than last bin of current speaker? 25 | last_a_lt_b = probs[..., backchanneler, -1] < probs[..., current_speaker, -1] 26 | 27 | # Within speaker diff 28 | # -------------------- 29 | # Is the start/middle bins of the "backchanneler" greater than the last bin? 30 | # I.e. does it predict an ending response? 31 | mid_a_max, _ = probs[..., backchanneler, :-1].max( 32 | dim=-1 33 | ) # get prob (used with threshold) 34 | mid_a_gt_last = ( 35 | mid_a_max > probs[..., backchanneler, -1] 36 | ) # find where the condition is true 37 | 38 | # Combine the between/within conditions 39 | non_zero_probs = torch.logical_and(last_a_lt_b, mid_a_gt_last) 40 | 41 | # Create probability tensor 42 | # P=0 where conditions are not met 43 | # P=max activation where condition is True 44 | tmp_pred_probs = torch.zeros_like(mid_a_max) 45 | tmp_pred_probs[non_zero_probs] = mid_a_max[non_zero_probs] 46 | bc_pred.append(tmp_pred_probs) 47 | 48 | bc_pred = torch.stack(bc_pred, dim=-1) 49 | return bc_pred 50 | 51 | 52 | def bin_times_to_frames(bin_times, frame_hz): 53 | bt = torch.tensor(bin_times) 54 | return (bt * frame_hz).long().tolist() 55 | 56 | 57 | class WindowHelper: 58 | @staticmethod 59 | def all_permutations_mono(n, start=0): 60 | vectors = [] 61 | for i in range(start, 2 ** n): 62 | i = bin(i).replace("0b", "").zfill(n) 63 | tmp = torch.zeros(n) 64 | for j, val in enumerate(i): 65 | tmp[j] = float(val) 66 | vectors.append(tmp) 67 | return torch.stack(vectors) 68 | 69 | @staticmethod 70 | def end_of_segment_mono(n, max=3): 71 | """ 72 | # 0, 0, 0, 0 73 | # 1, 0, 0, 0 74 | # 1, 1, 0, 0 75 | # 1, 1, 1, 0 76 | """ 77 | v = torch.zeros((max + 1, n)) 78 | for i in range(max): 79 | v[i + 1, : i + 1] = 1 80 | return v 81 | 82 | @staticmethod 83 | def on_activity_change_mono(n=4, min_active=2): 84 | """ 85 | 86 | Used where a single speaker is active. This vector (single speaker) represents 87 | the classes we use to infer that the current speaker will end their activity 88 | and the other take over. 89 | 90 | the `min_active` variable corresponds to the minimum amount of frames that must 91 | be active AT THE END of the projection window (for the next active speaker). 92 | This used to not include classes where the activity may correspond to a short backchannel. 93 | e.g. if only the last bin is active it may be part of just a short backchannel, if we require 2 bins to 94 | be active we know that the model predicts that the continuation will be at least 2 bins long and thus 95 | removes the ambiguouty (to some extent) about the prediction. 
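For intuition, the enumeration helpers above produce the following patterns for small sizes (a quick check; the import path assumes the package layout of this repo):

```python
from vap_turn_taking.vap import WindowHelper

# every activity pattern over two bins for a single speaker
print(WindowHelper.all_permutations_mono(2))
# tensor([[0., 0.],
#         [0., 1.],
#         [1., 0.],
#         [1., 1.]])

# "end of segment" patterns: activity covering the first 0..max bins
print(WindowHelper.end_of_segment_mono(4, max=2))
# tensor([[0., 0., 0., 0.],
#         [1., 0., 0., 0.],
#         [1., 1., 0., 0.]])
```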
96 | """ 97 | 98 | base = torch.zeros(n) 99 | # force activity at the end 100 | if min_active > 0: 101 | base[-min_active:] = 1 102 | 103 | # get all permutations for the remaining bins 104 | permutable = n - min_active 105 | if permutable > 0: 106 | perms = WindowHelper.all_permutations_mono(permutable) 107 | base = base.repeat(perms.shape[0], 1) 108 | base[:, :permutable] = perms 109 | return base 110 | 111 | @staticmethod 112 | def combine_speakers(x1, x2, mirror=False): 113 | if x1.ndim == 1: 114 | x1 = x1.unsqueeze(0) 115 | if x2.ndim == 1: 116 | x2 = x2.unsqueeze(0) 117 | vad = [] 118 | for a in x1: 119 | for b in x2: 120 | vad.append(torch.stack((a, b), dim=0)) 121 | 122 | vad = torch.stack(vad) 123 | if mirror: 124 | vad = torch.stack((vad, torch.stack((vad[:, 1], vad[:, 0]), dim=1))) 125 | return vad 126 | 127 | 128 | class VAPLabel(nn.Module): 129 | def __init__( 130 | self, 131 | bin_times: List = [0.2, 0.4, 0.6, 0.8], 132 | frame_hz: int = 100, 133 | threshold_ratio: float = 0.5, 134 | ): 135 | super().__init__() 136 | self.bin_times = bin_times 137 | self.frame_hz = frame_hz 138 | self.threshold_ratio = threshold_ratio 139 | 140 | self.bin_frames = bin_times_to_frames(bin_times, frame_hz) 141 | self.n_bins = len(self.bin_frames) 142 | self.total_bins = self.n_bins * 2 143 | self.horizon = sum(self.bin_frames) 144 | 145 | def __repr__(self) -> str: 146 | s = "VAPLabel(\n" 147 | s += f" bin_times: {self.bin_times}\n" 148 | s += f" bin_frames: {self.bin_frames}\n" 149 | s += f" frame_hz: {self.frame_hz}\n" 150 | s += f" thresh: {self.threshold_ratio}\n" 151 | s += ")\n" 152 | return s 153 | 154 | def projection(self, va): 155 | """ 156 | Extract projection (bins) 157 | (b, n, c) -> (b, N, c, M), M=horizon window size, N=valid frames 158 | 159 | Arguments: 160 | va: torch.Tensor (B, N, C) 161 | 162 | Returns: 163 | vaps: torch.Tensor (B, m, C, M) 164 | 165 | """ 166 | # Shift to get next frame projections 167 | return va[..., 1:, :].unfold(dimension=-2, size=sum(self.bin_frames), step=1) 168 | 169 | def vap_bins(self, projection_window): 170 | """ 171 | Iterate over the bin boundaries and sum the activity 172 | for each channel/speaker. 173 | divide by the number of frames to get activity ratio. 
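A quick shape check of `VAPLabel.projection` with fabricated activity (the (B, N, 2) layout at 100 Hz is assumed, matching the rest of this file):

```python
import torch
from vap_turn_taking.vap import VAPLabel

vap_label = VAPLabel(bin_times=[0.2, 0.4, 0.6, 0.8], frame_hz=100)  # horizon = 200 frames
va = (torch.rand(1, 300, 2) > 0.5).float()    # fake voice activity
pw = vap_label.projection(va)                 # one future window per valid frame
print(pw.shape)                               # torch.Size([1, 100, 2, 200])
```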
174 | If ratio is greater than or equal to the threshold_ratio 175 | the bin is considered active 176 | """ 177 | 178 | start = 0 179 | v_bins = [] 180 | for b in self.bin_frames: 181 | end = start + b 182 | m = projection_window[..., start:end].sum(dim=-1) / b 183 | m = (m >= self.threshold_ratio).float() 184 | v_bins.append(m) 185 | start = end 186 | return torch.stack(v_bins, dim=-1) # (*, t, c, n_bins) 187 | 188 | def comparative(self, projection_window): 189 | """ 190 | Sum together the activity for each speaker in the `projection_window` and get the activity 191 | ratio for each speaker (focused on speaker 0) 192 | p(speaker_1) = 1 - p(speaker_0) 193 | vad: torch.tensor, (B, N, 2) 194 | comp: torch.tensor, (B, N) 195 | """ 196 | comp = projection_window.sum(dim=-1) # sum all activity for speakers 197 | tot = comp.sum(dim=-1) + 1e-9 # get total activity 198 | # focus on speaker 0 and get ratio: p(speaker_1)= 1 - p(speaker_0) 199 | return comp[..., 0] / tot 200 | 201 | def __call__(self, va: torch.Tensor, type: str = "binary") -> torch.Tensor: 202 | projection_windows = self.projection(va) 203 | 204 | if type == "comparative": 205 | return self.comparative(projection_windows) 206 | 207 | return self.vap_bins(projection_windows) 208 | 209 | 210 | class ActivityEmb(nn.Module): 211 | def __init__(self, bin_times=[0.20, 0.40, 0.60, 0.80], frame_hz=100): 212 | super().__init__() 213 | self.frame_hz = frame_hz 214 | self.bin_frames = bin_times_to_frames(bin_times, frame_hz) 215 | self.n_bins = len(self.bin_frames) 216 | self.total_bins = self.n_bins * 2 217 | self.n_classes = 2 ** self.total_bins 218 | 219 | # Discrete indices 220 | if self.n_bins <= 5: 221 | self.codebook = self.init_codebook() 222 | 223 | # weighted by bin size (subset for active/silent is modified dependent on bin_frames) 224 | wsil, wact = self.init_subset_weighted_by_bin_size() 225 | self.subset_bin_weighted_silence = wsil 226 | self.subset_bin_weighted_active = wact 227 | 228 | self.subset_silence, self.subset_silence_hold = self.init_subset_silence() 229 | self.subset_active, self.subset_active_hold = self.init_subset_active() 230 | self.bc_prediction = self.init_subset_backchannel() 231 | self.requires_grad_(False) 232 | 233 | def init_codebook(self) -> nn.Module: 234 | """ 235 | Initializes the codebook for the vad-projection horizon labels. 236 | 237 | Map all vectors of binary digits of length `n_bins` to their corresponding decimal value. 238 | 239 | This allows a VAD future of shape (*, 4, 2) to be flatten to (*, 8) and mapped to a number 240 | corresponding to the class index. 241 | """ 242 | 243 | def single_idx_to_onehot(idx, d=8): 244 | assert idx < 2 ** d, "must be possible with {d} binary digits" 245 | z = torch.zeros(d) 246 | b = bin(idx).replace("0b", "") 247 | for i, v in enumerate(b[::-1]): 248 | z[i] = float(v) 249 | return z 250 | 251 | def create_code_vectors(n_bins): 252 | """ 253 | Create a matrix of all one-hot encodings representing a binary sequence of `self.total_bins` places 254 | Useful for usage in `nn.Embedding` like module. 
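To make the encoding concrete: each class index maps to an 8-element binary vector (later reshaped to 2 speakers x 4 bins) via its binary digits, least-significant bit first, mirroring `single_idx_to_onehot` above. Plain Python, no assumptions beyond that:

```python
idx = 5                                # binary '101'
bits = bin(idx).replace("0b", "")
onehot = [0.0] * 8
for i, v in enumerate(bits[::-1]):     # reversed -> least-significant bit at position 0
    onehot[i] = float(v)
print(onehot)  # [1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```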
255 | """ 256 | n_codes = 2 ** n_bins 257 | embs = torch.zeros((n_codes, n_bins)) 258 | for i in range(2 ** n_bins): 259 | embs[i] = single_idx_to_onehot(i, d=n_bins) 260 | return embs 261 | 262 | print("n_classes: ", self.n_classes) 263 | print("total_bins: ", self.total_bins) 264 | codebook = nn.Embedding( 265 | num_embeddings=self.n_classes, embedding_dim=self.total_bins 266 | ) 267 | codebook.weight.data = create_code_vectors(self.total_bins) 268 | codebook.weight.requires_grad_(False) 269 | return codebook 270 | 271 | def init_subset_weighted_by_bin_size(self): 272 | """ 273 | Calculates the comparative probability between the activity in each window for each speaker. 274 | 275 | a = sum(scale*activity_speaker_a) 276 | b = sum(scale*activity_speaker_b) 277 | p_a = a / (a+b) 278 | p_b = 1 - p_a 279 | """ 280 | 281 | def oh_to_prob(oh): 282 | tot = oh.sum(dim=-1).sum(dim=-1) 283 | a_comp = oh[:, 0].sum(-1) / (tot + 1e-9) 284 | # No activity counts as equal 285 | a_comp[a_comp == 0] = 0.5 286 | b_comp = 1 - a_comp 287 | return torch.stack((a_comp, b_comp), dim=-1) 288 | 289 | # get all onehot-states 290 | idx = torch.arange(self.n_classes) 291 | 292 | # normalize bin size weights -> adds to one 293 | scale_bins = torch.tensor(self.bin_frames, dtype=torch.float) 294 | scale_bins /= scale_bins.sum() 295 | 296 | # scale the bins of the onehot-states 297 | oh = scale_bins * self.idx_to_onehot(idx) 298 | subset_silence = oh_to_prob(oh) 299 | subset_active = oh_to_prob(oh[..., 2:]) 300 | return subset_silence, subset_active 301 | 302 | def sort_idx(self, x): 303 | if x.ndim == 1: 304 | x, _ = x.sort() 305 | elif x.ndim == 2: 306 | if x.shape[0] == 2: 307 | a, _ = x[0].sort() 308 | b, _ = x[1].sort() 309 | x = torch.stack((a, b)) 310 | else: 311 | x, _ = x[0].sort() 312 | x = x.unsqueeze(0) 313 | return x 314 | 315 | def init_subset_silence(self): 316 | """ 317 | During mutual silences we wish to infer which speaker the model deems most likely. 318 | 319 | We focus on classes where only a single speaker is active in the projection window, 320 | renormalize the probabilities on this subset, and determine which speaker is the most 321 | likely next speaker. 
322 | """ 323 | 324 | # active channel: At least 1 bin is active -> all permutations (all except the no-activity) 325 | # active = self._all_permutations_mono(n, start=1) # at least 1 active 326 | # active channel: At least 1 bin is active -> all permutations (all except the no-activity) 327 | active = WindowHelper.on_activity_change_mono(self.n_bins, min_active=2) 328 | # non-active channel: zeros 329 | non_active = torch.zeros((1, active.shape[-1])) 330 | # combine 331 | shift_oh = WindowHelper.combine_speakers(active, non_active, mirror=True) 332 | shift = self.onehot_to_idx(shift_oh) 333 | shift = self.sort_idx(shift) 334 | 335 | # symmetric, this is strictly unneccessary but done for convenience and to be similar 336 | # to 'get_on_activity_shift' where we actually have asymmetric classes for hold/shift 337 | hold = shift.flip(0) 338 | return shift, hold 339 | 340 | def init_subset_active(self): 341 | """On activity""" 342 | # Shift subset 343 | eos = WindowHelper.end_of_segment_mono(self.n_bins, max=2) 344 | nav = WindowHelper.on_activity_change_mono(self.n_bins, min_active=2) 345 | shift_oh = WindowHelper.combine_speakers(nav, eos, mirror=True) 346 | shift = self.onehot_to_idx(shift_oh) 347 | shift = self.sort_idx(shift) 348 | 349 | # Don't shift subset 350 | eos = WindowHelper.on_activity_change_mono(self.n_bins, min_active=2) 351 | zero = torch.zeros((1, self.n_bins)) 352 | hold_oh = WindowHelper.combine_speakers(zero, eos, mirror=True) 353 | hold = self.onehot_to_idx(hold_oh) 354 | hold = self.sort_idx(hold) 355 | return shift, hold 356 | 357 | def init_subset_backchannel(self, n=4): 358 | if n != 4: 359 | raise NotImplementedError("Not implemented for bin-size != 4") 360 | 361 | # at least 1 bin active over 3 bins 362 | bc_speaker = WindowHelper.all_permutations_mono(n=3, start=1) 363 | bc_speaker = torch.cat( 364 | (bc_speaker, torch.zeros((bc_speaker.shape[0], 1))), dim=-1 365 | ) 366 | 367 | # all permutations of 3 bins 368 | current = WindowHelper.all_permutations_mono(n=3, start=0) 369 | current = torch.cat((current, torch.ones((current.shape[0], 1))), dim=-1) 370 | 371 | bc_both = WindowHelper.combine_speakers(bc_speaker, current, mirror=True) 372 | return self.onehot_to_idx(bc_both) 373 | 374 | def onehot_to_idx(self, x) -> torch.Tensor: 375 | """ 376 | The inverse of the 'forward' function. 377 | 378 | Arguments: 379 | x: torch.Tensor (*, 2, 4) 380 | 381 | Inspiration for distance calculation: 382 | https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/vector_quantize_pytorch.py 383 | """ 384 | assert x.shape[-2:] == (2, self.n_bins) 385 | 386 | # compare with codebook and get closest idx 387 | shape = x.shape 388 | flatten = rearrange(x, "... c bpp -> (...) (c bpp)", c=2, bpp=self.n_bins) 389 | embed = self.codebook.weight.t() 390 | 391 | dist = -( 392 | flatten.pow(2).sum(1, keepdim=True) 393 | - 2 * flatten @ embed 394 | + embed.pow(2).sum(0, keepdim=True) 395 | ) 396 | 397 | embed_ind = dist.max(dim=-1).indices 398 | embed_ind = embed_ind.view(*shape[:-2]) 399 | return embed_ind 400 | 401 | def idx_to_onehot(self, idx): 402 | v = self.codebook(idx) 403 | return rearrange(v, "... (c b) -> ... 
c b", c=2) 404 | 405 | def forward(self, projection_window): 406 | return self.onehot_to_idx(projection_window) 407 | 408 | 409 | class Probabilites: 410 | def _normalize_ind_probs(self, probs): 411 | probs = probs.sum(dim=-1) # sum all bins for each speaker 412 | return probs / probs.sum(dim=-1, keepdim=True) # norm 413 | 414 | def _marginal_probs(self, probs, pos_idx, neg_idx): 415 | p = [] 416 | for next_speaker in [0, 1]: 417 | joint = torch.cat((pos_idx[next_speaker], neg_idx[next_speaker]), dim=-1) 418 | p_sum = probs[..., joint].sum(dim=-1) 419 | p.append(probs[..., pos_idx[next_speaker]].sum(dim=-1) / p_sum) 420 | return torch.stack(p, dim=-1) 421 | 422 | def _silence_probs(self, p_a, p_b, sil_probs, silence): 423 | w = torch.where(silence) 424 | p_a[w] = sil_probs[w][..., 0] 425 | p_b[w] = sil_probs[w][..., 1] 426 | return p_a, p_b 427 | 428 | def _single_speaker_probs(self, p0, p1, act_probs, current, other_speaker): 429 | w = torch.where(current) 430 | p0[w] = 1 - act_probs[w][..., other_speaker] # P_a = 1-P_b 431 | p1[w] = act_probs[w][..., other_speaker] # P_b 432 | return p0, p1 433 | 434 | def _overlap_probs(self, p_a, p_b, act_probs, both): 435 | """ 436 | P_a_prior=A is next (active) 437 | P_b_prior=B is next (active) 438 | We the compare/renormalize given the two values of A/B is the next speaker 439 | sum = P_a_prior+P_b_prior 440 | P_a = P_a_prior / sum 441 | P_b = P_b_prior / sum 442 | """ 443 | w = torch.where(both) 444 | # Re-Normalize and compare next-active 445 | sum = act_probs[w][..., 0] + act_probs[w][..., 1] 446 | 447 | p_a[w] = act_probs[w][..., 0] / sum 448 | p_b[w] = act_probs[w][..., 1] / sum 449 | return p_a, p_b 450 | 451 | 452 | class VAP(nn.Module, Probabilites): 453 | TYPES = ["discrete", "independent", "comparative"] 454 | 455 | def __init__( 456 | self, 457 | type="discrete", 458 | bin_times=[0.20, 0.40, 0.60, 0.80], 459 | frame_hz=100, 460 | pre_frames=2, 461 | threshold_ratio=0.5, 462 | ): 463 | super().__init__() 464 | assert type in VAP.TYPES, "{type} is not a valid type! 
{VAP.TYPES}" 465 | 466 | self.type = type 467 | self.frame_hz = frame_hz 468 | self.bin_times = bin_times 469 | self.emb = ActivityEmb(bin_times, frame_hz) 470 | self.vap_label = VAPLabel(bin_times, frame_hz, threshold_ratio) 471 | self.horizon = torch.tensor(self.bin_times).sum(0).item() 472 | self.horizon_frames = int(self.horizon * frame_hz) 473 | self.pre_frames = pre_frames 474 | 475 | @property 476 | def vap_bins(self): 477 | n = torch.arange(self.emb.n_classes, device=self.emb.codebook.weight.device) 478 | return self.emb.idx_to_onehot(n) 479 | 480 | def __repr__(self): 481 | s = super().__repr__().split("\n") 482 | s.insert(1, f" type: {self.type}") 483 | s = "\n".join(s) 484 | return s 485 | 486 | def _probs_on_silence(self, probs): 487 | return self._marginal_probs( 488 | probs, self.emb.subset_silence, self.emb.subset_silence_hold 489 | ) 490 | 491 | def _probs_on_active(self, probs): 492 | return self._marginal_probs( 493 | probs, self.emb.subset_active, self.emb.subset_active_hold 494 | ) 495 | 496 | def _probs_ind_on_silence(self, probs): 497 | return self._normalize_ind_probs(probs) 498 | 499 | def _probs_ind_on_active(self, probs): 500 | return self._normalize_ind_probs(probs[..., :, self.pre_frames :]) 501 | 502 | def _probs_weighted_on_silence(self, probs): 503 | sil_probs = probs.unsqueeze( 504 | -1 505 | ) * self.emb.subset_bin_weighted_silence.unsqueeze(0).to(probs.device) 506 | return sil_probs.sum(dim=-2) # sum over classes 507 | 508 | def _probs_weighted_on_active(self, probs): 509 | # comparative active 510 | act_probs = probs.unsqueeze(-1) * self.emb.subset_bin_weighted_active.unsqueeze( 511 | 0 512 | ).to(probs.device) 513 | return act_probs.sum(dim=-2) # sum over classes 514 | 515 | def probs_backchannel(self, probs): 516 | ap = probs[..., self.emb.bc_prediction[0]].sum(-1) 517 | bp = probs[..., self.emb.bc_prediction[1]].sum(-1) 518 | return torch.stack((ap, bp), dim=-1) 519 | 520 | def probs_next_speaker(self, probs, va, type): 521 | """ 522 | Extracts the probabilities for who the next speaker is dependent on what the current moment is. 
523 | 524 | This means that on mutual silences we use the 'silence'-subset, 525 | where a single speaker is active we use the 'active'-subset, and on overlaps we renormalize the two 'active' probabilities against each other. 526 | """ 527 | if type == "independent": 528 | sil_probs = self._probs_ind_on_silence(probs) 529 | act_probs = self._probs_ind_on_active(probs) 530 | elif type == "weighted": 531 | sil_probs = self._probs_weighted_on_silence(probs) 532 | act_probs = self._probs_weighted_on_active(probs) 533 | elif type == "comparative": 534 | pB = 1 - probs 535 | sil_probs = act_probs = torch.cat((probs, pB), dim=-1) 536 | else: # discrete 537 | sil_probs = self._probs_on_silence(probs) 538 | act_probs = self._probs_on_active(probs) 539 | 540 | # Start with all zeros 541 | # p_a: probability of A being the next speaker (channel: 0) 542 | # p_b: probability of B being the next speaker (channel: 1) 543 | p_a = torch.zeros_like(va[..., 0]) 544 | p_b = torch.zeros_like(va[..., 0]) 545 | 546 | # dialog states 547 | ds = vad_to_dialog_vad_states(va) 548 | silence = ds == 1 549 | a_current = ds == 0 550 | b_current = ds == 3 551 | both = ds == 2 552 | 553 | # silence 554 | p_a, p_b = self._silence_probs(p_a, p_b, sil_probs, silence) 555 | 556 | # A current speaker 557 | # Given that only A is speaking we use the 'active' probability of B being the next speaker 558 | p_a, p_b = self._single_speaker_probs( 559 | p_a, p_b, act_probs, a_current, other_speaker=1 560 | ) 561 | 562 | # B current speaker 563 | # Given that only B is speaking we use the 'active' probability of A being the next speaker 564 | p_b, p_a = self._single_speaker_probs( 565 | p_b, p_a, act_probs, b_current, other_speaker=0 566 | ) 567 | 568 | # Both active (overlap) 569 | p_a, p_b = self._overlap_probs(p_a, p_b, act_probs, both) 570 | 571 | p_probs = torch.stack((p_a, p_b), dim=-1) 572 | return p_probs 573 | 574 | def extract_label(self, va: torch.Tensor) -> torch.Tensor: 575 | if self.type == "comparative": 576 | return self.vap_label(va, type="comparative") 577 | 578 | vap_bins = self.vap_label(va, type="binary") 579 | 580 | if self.type == "independent": 581 | return vap_bins 582 | 583 | return self.emb(vap_bins) # discrete 584 | 585 | def forward(self, logits, va): 586 | """ 587 | Probabilities for events, dependent on the VAP-embedding and VA. 588 | """ 589 | 590 | # Next speaker probs 591 | if self.type == "discrete": 592 | probs = logits.softmax(dim=-1) 593 | p = self.probs_next_speaker(probs=probs, va=va, type=self.type) 594 | p_bc = self.probs_backchannel(probs) 595 | else: 596 | probs = logits.sigmoid() 597 | p = self.probs_next_speaker(probs=probs, va=va, type=self.type) 598 | p_bc = None # comparative 599 | if self.type == "independent": 600 | # Backchannel probs (dependent on the embedding and VA) 601 | p_bc = probs_ind_backchannel(probs) 602 | 603 | return {"p": p, "bc_prediction": p_bc} 604 | 605 | 606 | if __name__ == "__main__": 607 | from vap_turn_taking.config.example_data import example 608 | 609 | vapper = VAP(type="comparative") 610 | va = example["va"] 611 | y = vapper.extract_label(va) 612 | print("va: ", tuple(va.shape)) 613 | print("y: ", tuple(y.shape)) 614 | print(vapper) 615 | --------------------------------------------------------------------------------
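Taken together, `VAPLabel`, `ActivityEmb`, and `VAP` turn a voice-activity tensor into training labels and turn model logits into next-speaker and backchannel probabilities. The sketch below is a minimal illustration of that flow; the synthetic `va` tensor and the random logits are assumptions made for the example (the real inputs come from vap_turn_taking.config.example_data and a trained model), while the shapes follow the defaults above (frame_hz=100, a 2 s horizon, 4 bins per speaker, 256 discrete classes).

# Minimal usage sketch (synthetic inputs; not part of the package itself).
import torch
from vap_turn_taking import VAP

vap = VAP(type="discrete")  # bin_times=[0.2, 0.4, 0.6, 0.8], frame_hz=100

# Synthetic voice activity: batch of 1, 500 frames (5 s at 100 Hz), 2 speakers.
va = (torch.rand(1, 500, 2) > 0.5).float()

# Discrete labels: one class index in [0, 255] per frame that still has a full
# 2 s projection horizon ahead of it -> shape (1, 300).
y = vap.extract_label(va)

# Random stand-in for network output over the 256 discrete classes, mapped to
# next-speaker and backchannel probabilities; va is sliced to the same frames.
logits = torch.randn(1, y.shape[1], vap.emb.n_classes)
out = vap(logits, va[:, : y.shape[1]])
print(tuple(y.shape), tuple(out["p"].shape), tuple(out["bc_prediction"].shape))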