├── .amlignore ├── .devcontainer └── devcontainer.json ├── .gitattributes ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── DATA_LICENSE ├── Dockerfile ├── LICENSE ├── README.md ├── SECURITY.md ├── Transparency_FAQ.md ├── asr └── asr.py ├── configs ├── conf_demo.yaml ├── inference │ ├── debug_inference.yaml │ ├── diarization │ │ └── nemo │ │ │ ├── diar_infer_general.yaml │ │ │ ├── diar_infer_meeting.yaml │ │ │ └── diar_infer_telephonic.yaml │ └── inference_v1.yaml └── train_css │ └── local │ ├── conformer_v0.51_mc.yaml │ ├── conformer_v0.51_sc.yaml │ ├── conformer_v0.5_mc.yaml │ ├── conformer_v0.5_sc.yaml │ ├── conformer_v1.0_mc.yaml │ ├── conformer_v1.0_sc.yaml │ ├── debug_mc.yaml │ └── debug_sc.yaml ├── css ├── css.py ├── css_with_conformer │ ├── README.md │ ├── executor │ │ ├── __init__.py │ │ ├── executor.py │ │ └── feature.py │ ├── nnet │ │ ├── __init__.py │ │ └── conformer.py │ ├── separate.py │ └── utils │ │ ├── __init__.py │ │ ├── audio_util.py │ │ ├── mvdr_util.py │ │ ├── overlapped_speech_1ch.scp │ │ └── overlapped_speech_7ch.scp ├── helpers.py └── training │ ├── augmentations.py │ ├── conformer_wrapper.py │ ├── losses.py │ ├── schedulers.py │ ├── simulated_dataset.py │ └── train.py ├── diarization ├── diarization.py ├── diarization_common.py ├── time_based_diarization.py └── word_based_diarization.py ├── inference_pipeline ├── inference.py └── load_meeting_data.py ├── requirements.txt ├── run_inference.py ├── run_training_css_local.py ├── sample_data └── css_train_set │ ├── 0000_Libri.clean-500.book_00082_chp_0009_reader_02124_53.gt_noise │ ├── 0000_Libri.clean-500.book_00082_chp_0009_reader_02124_53.gt_spk_activity_scores │ ├── 0000_Libri.clean-500.book_00082_chp_0009_reader_02124_53.gt_spk_direct_early_echoes │ ├── 0000_Libri.clean-500.book_00082_chp_0009_reader_02124_53.gt_spk_reverb │ ├── 0000_Libri.clean-500.book_00082_chp_0009_reader_02124_53.json │ ├── 0000_Libri.clean-500.book_00082_chp_0009_reader_02124_53.mixture │ └── dataset-000000.map └── utils ├── audio_utils.py ├── azure_storage.py ├── conf.py ├── hugging_face_helper.py ├── logging_def.py ├── mic_array_model.py ├── notsofar_dataset.py ├── numpy_utils.py ├── plot_utils.py ├── results_analysis.py ├── scoring.py ├── text_norm_whisper_like ├── LICENSE ├── __init__.py ├── basic.py ├── english.json ├── english.py └── pre_english.json └── torch_utils.py /.amlignore: -------------------------------------------------------------------------------- 1 | #### 2 | # The first part of this file is a copy of .gitignore. The second part holds additions relevant for Azure ML only. 3 | # FIRST PART 4 | #### 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | cover/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # poetry 103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 104 | # This is especially recommended for binary packages to ensure reproducibility, and is more 105 | # commonly ignored for libraries. 106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 107 | #poetry.lock 108 | 109 | # pdm 110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 111 | #pdm.lock 112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 113 | # in version control. 114 | # https://pdm.fming.dev/#use-with-ide 115 | .pdm.toml 116 | 117 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # Python cache 161 | *.pyc 162 | __pycache__ 163 | 164 | # PyCharm 165 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 166 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 167 | # and can be added to the global gitignore or merged into this file. For a more nuclear 168 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 169 | .idea/ 170 | 171 | # Mlflow outputs 172 | mlruns/ 173 | 174 | #### 175 | # The first part of this file is a copy of .gitignore. The second part holds additions relevant for Azure ML only. 
176 | # SECOND PART 177 | #### 178 | /sample_data 179 | /.git 180 | /css/css_with_conformer/utils/*.scp 181 | # artifacts includes various by products such as: intermediate module outputs, downloaded data, models. 182 | /artifacts 183 | /outputs -------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | // For format details, see https://aka.ms/devcontainer.json. For config options, see the 2 | // README at: https://github.com/devcontainers/templates/tree/main/src/python 3 | { 4 | "name": "NOTSOFAR1", 5 | "image": "mcr.microsoft.com/devcontainers/python:1-3.10-bookworm", 6 | "features": { 7 | "ghcr.io/devcontainers/features/nvidia-cuda:1": { 8 | "installCudnn": true, 9 | "cudaVersion": "11.7", 10 | "cudnnVersion": "8.5.0.96" 11 | }, 12 | "ghcr.io/devcontainers/features/azure-cli:1": {} 13 | }, 14 | 15 | "hostRequirements": { 16 | "gpu": "optional" 17 | }, 18 | 19 | "runArgs": [ 20 | "--gpus", 21 | "all" 22 | ], 23 | 24 | // Use 'forwardPorts' to make a list of ports inside the container available locally. 25 | // "forwardPorts": [], 26 | 27 | "postCreateCommand": "sudo apt-get update && sudo apt-get install -y ffmpeg && python3 -m pip install --upgrade pip && pip3 install --user -r requirements.txt" 28 | 29 | // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. 30 | // "remoteUser": "root" 31 | } 32 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Declare files that will always have LF line endings on checkout. This is especially important for *.sh files, since 2 | # if CRLF is written into these files, they may cause errors when invoked (e.g. "not found"). 3 | *.sh text eol=lf 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # Python cache 156 | *.pyc 157 | __pycache__ 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 164 | .idea/ 165 | 166 | # artifacts includes various by products such as: intermediate module outputs, downloaded data, models. 
167 | /artifacts 168 | /outputs 169 | 170 | # Mlflow outputs 171 | mlruns/ 172 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribute to *NOTSOFAR1* 2 | Our repository is read-only, so there are a few steps required to push your code into it: 3 | 4 | 1. **Fork the Repository**: Users who want to contribute should fork the repository to their own GitHub account. 5 | 1. **Make Changes**: Make changes to the forked repository. 6 | 1. **Create a Pull Request**: After making changes, create a PR to propose your changes to *NOTSOFAR1*. Do this by: 7 | - Navigating to the forked repository. 8 | - Clicking on the "Pull Requests" tab, and then clicking the "New Pull Request" button. 9 | - You will then be able to select **your forked repository** and the branch you want to merge changes into `NOTSOFAR1/main`. 10 | 1. **Submit the PR**: Fill out the PR form, providing details about the changes you have made. Once ready, submit the PR. 11 | 1. **Review and Merge**: One of the *NOTSOFAR1* team members will review the PR, provide feedback, and eventually merge it into the main repository if the changes are deemed appropriate. 12 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Base 2 | FROM mcr.microsoft.com/azureml/openmpi4.1.0-cuda11.8-cudnn8-ubuntu22.04:latest as base 3 | 4 | # The base image comes with Python 3.10.13 preinstalled. Hence, no need to install it here. 
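# Example usage (a sketch only; the image tag and mount path below are illustrative and not defined by this repo):
#   docker build -t notsofar-baseline --build-arg REQUIREMENTS_FILE=requirements.txt .
#   docker run --rm --gpus all -v "$PWD":/workspace -w /workspace -it notsofar-baseline bash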
5 | 6 | # Argument for specifying requirements file, default is "requirements.txt" 7 | ARG REQUIREMENTS_FILE=requirements.txt 8 | 9 | # Libs 10 | RUN apt-get -y update && \ 11 | apt-get install -y software-properties-common ffmpeg libportaudio2 \ 12 | libasound-dev git git-lfs zlib1g-dev libreadline-dev \ 13 | libncursesw5-dev libnss3-dev libssl-dev libsqlite3-dev tk-dev libgdbm-dev libc6-dev libbz2-dev \ 14 | libsndfile1 libsnappy-dev openmpi-bin graphviz libsm6 libxext6 libxrender-dev make build-essential \ 15 | curl llvm libncurses5-dev xz-utils libxml2-dev libxmlsec1-dev libffi-dev liblzma-dev 16 | 17 | RUN wget http://es.archive.ubuntu.com/ubuntu/pool/main/libf/libffi/libffi7_3.3-4_amd64.deb 18 | RUN dpkg -i libffi7_3.3-4_amd64.deb 19 | 20 | # A few tools to aid debugging 21 | RUN apt-get -y install smem dstat man less screen 22 | RUN pip install ps_mem ipython 23 | 24 | # Virtualenv 25 | RUN pip install virtualenv 26 | 27 | # dependencies for running the inference pipeline packages 28 | RUN python -m pip install --upgrade pip 29 | RUN pip install --upgrade setuptools wheel Cython fasttext-wheel 30 | RUN apt-get install python3.10-dev ffmpeg build-essential 31 | 32 | # Packages 33 | ARG CACHE_BUST 34 | COPY *requirements.txt /home/ 35 | RUN pip install -r /home/${REQUIREMENTS_FILE} -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below. 
8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /Transparency_FAQ.md: -------------------------------------------------------------------------------- 1 | # What does the repository contain? 2 | 3 | This repository contains the NOTSOFAR-1 baseline system for speaker-attributed distant meeting transcription. 4 | It includes a CSS (continuous speech separation) model trained on NOTSOFAR's simulated training dataset, along with CSS inference, diarization, and speech recognition modules. 5 | The repository also contains inference pipeline code, data downloading and processing code, and code to measure word-error-rate-based metrics. 6 | 7 | # What are these components’ intended use(s)? 8 | 9 | The scientific community will use this system as a baseline to perform research on distant speaker-attributed speech recognition. They will be able to use their own data and methods and extend this baseline to improve it. Then, the evaluation framework can be used to assess the level of improvement their models have achieved. 10 | 11 | # What are the limitations of these components? 12 | 13 | 1. The system components are not fine-tuned to be state-of-the-art candidates for speech separation, diarization, and recognition. We only provide a baseline approach. 14 | 2. 
The inference code we provide is compatible only with the original models and may not be compatible once modifications are applied to the models. 15 | 3. The inference code assumes a certain data structure, documented inside the code, and may not work properly with data of a different structure. 16 | 4. The evaluation framework provides only a limited depth of analysis of model performance. The evaluation metrics may correlate poorly with a user's subjective assessment of what constitutes good performance. 17 | 5. High-energy environmental noise may cause performance degradation that has not been quantified in our experiments. 18 | 6. The models we provide are intended for English speakers, and performance may vary with English accents and dialects. We have not evaluated the system on languages other than English. 19 | -------------------------------------------------------------------------------- /asr/asr.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from pathlib import Path 3 | from typing import Optional 4 | 5 | import pandas as pd 6 | import whisper 7 | from tqdm import tqdm 8 | 9 | from utils.logging_def import get_logger 10 | from utils.text_norm_whisper_like import get_txt_norm 11 | 12 | _LOG = get_logger('asr') 13 | 14 | 15 | @dataclass 16 | class WhisperAsrCfg: 17 | model_name: str = 'large-v2' # use 'large-v2' for experiments, use 'tiny' for fast debugging 18 | language: Optional[str] = 'en' # language that the speech is in (if None, whisper runs language ID) 19 | word_level_time_stamps: bool = True 20 | beam_size: Optional[int] = 5 21 | hallucination_silence_threshold: Optional[float] = 2. 22 | 23 | def text_normalizer(self): 24 | return get_txt_norm("chime8") 25 | 26 | def assert_valid(self): 27 | assert self.model_name in ['tiny.en', 'tiny', 'base.en', 'base', 'small.en', 'small', 'medium.en', 28 | 'medium', 'large-v1', 'large-v2', 'large-v3', 'large'] 29 | 30 | 31 | def asr_inference(out_dir: str, session: pd.Series, cfg: WhisperAsrCfg, fetch_from_cache: bool): 32 | """ 33 | Applies automatic speech recognition using Whisper - an ASR model by OpenAI. 34 | 35 | Args: 36 | out_dir: the outputs per module are saved to out_dir/{module_name}/{session_id}. 37 | session: Row representing session to evaluate. 38 | cfg: Specifies Whisper's configuration (paths, model parameters, etc). 39 | fetch_from_cache: If True, returns the cached results if they exist. Otherwise, runs the inference. 40 | Returns: 41 | segments_df: a dataframe of transcribed segments returned by ASR with the following columns: 42 | 'start_time': start time of the segment in seconds. 43 | 'end_time': end time of the segment in seconds. 44 | 'text': the text of the segment. 45 | 'word_timing': a list of [word, start, end] lists. 46 | 'meeting_id': the meeting id. 47 | 'session_id': the session id. 48 | 'wav_file_name': the name of the wav file that the segment was transcribed from. 
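
        Example (illustrative usage sketch; `session` only needs the fields accessed below,
        i.e. `session_id`, `meeting_id` and a `sep_wav_file_names` list of separated wav paths):
            cfg = WhisperAsrCfg(model_name='tiny')
            segments_df = asr_inference(out_dir, session, cfg, fetch_from_cache=True)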
49 | """ 50 | _LOG.info('Running ASR') 51 | cfg.assert_valid() 52 | decode_options = dict(language=cfg.language, 53 | word_timestamps=cfg.word_level_time_stamps, 54 | beam_size=cfg.beam_size, 55 | hallucination_silence_threshold=cfg.hallucination_silence_threshold) 56 | transcribe_options = dict(task="transcribe", **decode_options) 57 | 58 | wav_files = session.sep_wav_file_names 59 | assert isinstance(wav_files, list) 60 | 61 | out_file = Path(out_dir) / 'asr' / session.session_id / cfg.model_name / "all_segments_df.pkl" 62 | 63 | if fetch_from_cache and out_file.exists(): 64 | _LOG.info(f'Loading ASR results from {out_file}') 65 | all_segments_df = pd.read_pickle(out_file) 66 | return all_segments_df 67 | 68 | _LOG.info(f'Loading Whisper model: {cfg.model_name}') 69 | model = whisper.load_model(cfg.model_name) 70 | 71 | _LOG.info(f'Running ASR on {len(wav_files)} streams') 72 | segments_dfs = [] 73 | for wav_file in tqdm(wav_files, desc='running ASR'): 74 | results = model.transcribe(str(wav_file), **transcribe_options) 75 | if len(results['segments']) == 0: 76 | _LOG.warning(f'No segments returned for {wav_file}') 77 | continue 78 | raw_seg_df = pd.DataFrame(results['segments']) 79 | 80 | # each entry in this column is a list of [word, start, end] lists 81 | word_start_end = raw_seg_df['words'].apply(lambda x: [[w['word'], w['start'], w['end']] for w in x]) 82 | 83 | segments_df = pd.DataFrame( 84 | {'start_time': raw_seg_df['start'], 85 | 'end_time': raw_seg_df['end'], 86 | 'text': raw_seg_df['text'], 87 | 'word_timing': word_start_end}) 88 | 89 | segments_df['meeting_id'] = session.meeting_id 90 | segments_df['session_id'] = session.session_id 91 | segments_df['wav_file_name'] = wav_file 92 | 93 | segments_dfs.append(segments_df) 94 | 95 | all_segments_df = pd.concat(segments_dfs, ignore_index=True) 96 | 97 | out_file.parent.mkdir(parents=True, exist_ok=True) 98 | all_segments_df.to_pickle(out_file) 99 | _LOG.info(f'ASR results saved to {out_file}') 100 | 101 | return all_segments_df 102 | 103 | -------------------------------------------------------------------------------- /configs/conf_demo.yaml: -------------------------------------------------------------------------------- 1 | # * Key-value pairs that do not appear here will be set to default values defined in the dataclasses in utils/conf.py. 2 | # * Key names and value types will be verified. 3 | css: 4 | lr: 0.017 5 | -------------------------------------------------------------------------------- /configs/inference/debug_inference.yaml: -------------------------------------------------------------------------------- 1 | # * Key-value pairs that do not appear here will be set to default values defined in the dataclasses. 2 | # * Key names and value types will be verified. 3 | 4 | asr: 5 | model_name: 'tiny' 6 | 7 | css: 8 | segment_size_sec: 3. 9 | hop_size_sec: 1.5 10 | device: "cuda:0" 11 | show_progressbar: True 12 | slice_audio_for_debug: False 13 | mc_mvdr: True 14 | mc_mask_floor_db: 0. 
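  # Interpretation (based on the comments in inference_v1.yaml, not verified against css/css.py):
  # the mask floor clamps the estimated masks from below, in dB, before they are applied.
  # A floor of 0 dB effectively disables direct masking (separation then relies on MVDR alone),
  # while -inf leaves the masks unconstrained.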
15 | sc_mask_floor_db: -inf 16 | activity_th: 0.3 17 | 18 | diarization: 19 | method: 'word_nmesc' # choose from "word_nmesc", "nmesc" and "nmesc_msdd" 20 | min_embedding_windows: [3.0,2.5,2.0,1.5,1.0,0.5] 21 | embedding_model_name: "titanet_large" 22 | msdd_model_name: "diar_msdd_telephonic" 23 | # vad_model_name: "vad_telephony_marblenet" # for 8kHz telephone 24 | vad_model_name: "vad_multilingual_marblenet" # for 16kHz 25 | 26 | 27 | 28 | ## one SC session: 29 | #session_query: 'device_name == "plaza_0" and is_mc == False and meeting_id == "MTG_30830"' 30 | 31 | # Plaza SC sessions that are not read meetings: 32 | #session_query: 'device_name == "plaza_0" and is_mc == False and ~MtgType.str.startswith("read")' 33 | 34 | # all SC sessions that are not read meetings: 35 | #session_query: 'is_mc == False and ~MtgType.str.startswith("read")' 36 | 37 | # ------Multi-Channel------ 38 | # all MC sessions that are not read meetings: 39 | #session_query: 'is_mc == True and ~MtgType.str.startswith("read")' 40 | 41 | # one MC session: 42 | #session_query: 'is_mc == True and meeting_id == "MTG_30830" and device_name == "rockfall_0"' 43 | 44 | # all sessions: 45 | session_query: null # NOTE: this may be overridden elsewhere 46 | -------------------------------------------------------------------------------- /configs/inference/diarization/nemo/diar_infer_general.yaml: -------------------------------------------------------------------------------- 1 | # This YAML file is created for all types of offline speaker diarization inference tasks in `/example/speaker_tasks/diarization` folder. 2 | # The inference parameters for VAD, speaker embedding extractor, clustering module, MSDD module, ASR decoder are all included in this YAML file. 3 | # All the keys under `diarizer` key (`vad`, `speaker_embeddings`, `clustering`, `msdd_model`, `asr`) can be selectively used for its own purpose and also can be ignored if the module is not used. 4 | # The configurations in this YAML file is optimized to show balanced performances on various types of domain. VAD is optimized on multilingual ASR datasets and diarizer is optimized on DIHARD3 development set. 5 | # An example line in an input manifest file (`.json` format): 6 | # {"audio_filepath": "/path/to/audio_file", "offset": 0, "duration": null, "label": "infer", "text": "-", "num_speakers": null, "rttm_filepath": "/path/to/rttm/file", "uem_filepath": "/path/to/uem/file"} 7 | name: &name "ClusterDiarizer" 8 | 9 | num_workers: 1 10 | sample_rate: 16000 11 | batch_size: 32 12 | device: null # can specify a specific device, i.e: cuda:1 (default cuda if cuda available, else cpu) 13 | verbose: True # enable additional logging 14 | 15 | diarizer: 16 | manifest_filepath: ??? 17 | out_dir: ??? 18 | oracle_vad: False # If True, uses RTTM files provided in the manifest file to get speech activity (VAD) timestamps 19 | collar: 0.25 # Collar value for scoring 20 | ignore_overlap: True # Consider or ignore overlap segments while scoring 21 | 22 | vad: 23 | model_path: vad_multilingual_marblenet # .nemo local model path or pretrained VAD model name 24 | external_vad_manifest: null # This option is provided to use external vad and provide its speech activity labels for speaker embeddings extraction. 
Only one of model_path or external_vad_manifest should be set 25 | 26 | parameters: # Tuned by detection error rate (false alarm + miss) on multilingual ASR evaluation datasets 27 | window_length_in_sec: 0.63 # Window length in sec for VAD context input 28 | shift_length_in_sec: 0.08 # Shift length in sec for generate frame level VAD prediction 29 | smoothing: False # False or type of smoothing method (eg: median) 30 | overlap: 0.5 # Overlap ratio for overlapped mean/median smoothing filter 31 | onset: 0.5 # Onset threshold for detecting the beginning and end of a speech 32 | offset: 0.3 # Offset threshold for detecting the end of a speech 33 | pad_onset: 0.2 # Adding durations before each speech segment 34 | pad_offset: 0.2 # Adding durations after each speech segment 35 | min_duration_on: 0.5 # Threshold for small non_speech deletion 36 | min_duration_off: 0.5 # Threshold for short speech segment deletion 37 | filter_speech_first: True 38 | 39 | speaker_embeddings: 40 | model_path: titanet_large # .nemo local model path or pretrained model name (titanet_large, ecapa_tdnn or speakerverification_speakernet) 41 | parameters: 42 | window_length_in_sec: [1.9,1.2,0.5] # Window length(s) in sec (floating-point number). either a number or a list. ex) 1.5 or [1.5,1.0,0.5] 43 | shift_length_in_sec: [0.95,0.6,0.25] # Shift length(s) in sec (floating-point number). either a number or a list. ex) 0.75 or [0.75,0.5,0.25] 44 | multiscale_weights: [1,1,1] # Weight for each scale. should be null (for single scale) or a list matched with window/shift scale count. ex) [0.33,0.33,0.33] 45 | save_embeddings: True # If True, save speaker embeddings in pickle format. This should be True if clustering result is used for other models, such as `msdd_model`. 46 | 47 | clustering: 48 | parameters: 49 | oracle_num_speakers: False # If True, use num of speakers value provided in manifest file. 50 | max_num_speakers: 8 # Max number of speakers for each recording. If an oracle number of speakers is passed, this value is ignored. 51 | enhanced_count_thres: 80 # If the number of segments is lower than this number, enhanced speaker counting is activated. 52 | max_rp_threshold: 0.25 # Determines the range of p-value search: 0 < p <= max_rp_threshold. 53 | sparse_search_volume: 10 # The higher the number, the more values will be examined with more time. 54 | maj_vote_spk_count: False # If True, take a majority vote on multiple p-values to estimate the number of speakers. 55 | 56 | msdd_model: 57 | model_path: null # .nemo local model path or pretrained model name for multiscale diarization decoder (MSDD) 58 | parameters: 59 | use_speaker_model_from_ckpt: True # If True, use speaker embedding model in checkpoint. If False, the provided speaker embedding model in config will be used. 60 | infer_batch_size: 25 # Batch size for MSDD inference. 61 | sigmoid_threshold: [0.7] # Sigmoid threshold for generating binarized speaker labels. The smaller the more generous on detecting overlaps. 62 | seq_eval_mode: False # If True, use oracle number of speaker and evaluate F1 score for the given speaker sequences. Default is False. 63 | split_infer: True # If True, break the input audio clip to short sequences and calculate cluster average embeddings for inference. 64 | diar_window_length: 50 # The length of split short sequence when split_infer is True. 65 | overlap_infer_spk_limit: 5 # If the estimated number of speakers are larger than this number, overlap speech is not estimated. 
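  # Illustrative only (not part of the NeMo config schema): a config like this one is typically
  # consumed roughly as follows. The manifest and output paths are hypothetical placeholders for
  # the required `???` fields above.
  #
  #   from omegaconf import OmegaConf
  #   from nemo.collections.asr.models import ClusteringDiarizer
  #
  #   cfg = OmegaConf.load('configs/inference/diarization/nemo/diar_infer_general.yaml')
  #   cfg.diarizer.manifest_filepath = 'artifacts/diarization/manifest.json'
  #   cfg.diarizer.out_dir = 'artifacts/diarization/out'
  #   ClusteringDiarizer(cfg=cfg).diarize()  # writes predicted RTTM files under out_dir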
66 | 67 | asr: 68 | model_path: null # Provide NGC cloud ASR model name. stt_en_conformer_ctc_* models are recommended for diarization purposes. 69 | parameters: 70 | asr_based_vad: False # if True, speech segmentation for diarization is based on word-timestamps from ASR inference. 71 | asr_based_vad_threshold: 1.0 # Threshold (in sec) that caps the gap between two words when generating VAD timestamps using ASR based VAD. 72 | asr_batch_size: null # Batch size can be dependent on each ASR model. Default batch sizes are applied if set to null. 73 | decoder_delay_in_sec: null # Native decoder delay. null is recommended to use the default values for each ASR model. 74 | word_ts_anchor_offset: null # Offset to set a reference point from the start of the word. Recommended range of values is [-0.05 0.2]. 75 | word_ts_anchor_pos: "start" # Select which part of the word timestamp we want to use. The options are: 'start', 'end', 'mid'. 76 | fix_word_ts_with_VAD: False # Fix the word timestamp using VAD output. You must provide a VAD model to use this feature. 77 | colored_text: False # If True, use colored text to distinguish speakers in the output transcript. 78 | print_time: True # If True, the start and end time of each speaker turn is printed in the output transcript. 79 | break_lines: False # If True, the output transcript breaks the line to fix the line width (default is 90 chars) 80 | 81 | ctc_decoder_parameters: # Optional beam search decoder (pyctcdecode) 82 | pretrained_language_model: null # KenLM model file: .arpa model file or .bin binary file. 83 | beam_width: 32 84 | alpha: 0.5 85 | beta: 2.5 86 | 87 | realigning_lm_parameters: # Experimental feature 88 | arpa_language_model: null # Provide a KenLM language model in .arpa format. 89 | min_number_of_words: 3 # Min number of words for the left context. 90 | max_number_of_words: 10 # Max number of words for the right context. 91 | logprob_diff_threshold: 1.2 # The threshold for the difference between two log probability values from two hypotheses. 92 | 93 | -------------------------------------------------------------------------------- /configs/inference/diarization/nemo/diar_infer_meeting.yaml: -------------------------------------------------------------------------------- 1 | # This YAML file is created for all types of offline speaker diarization inference tasks in `/example/speaker_tasks/diarization` folder. 2 | # The inference parameters for VAD, speaker embedding extractor, clustering module, MSDD module, ASR decoder are all included in this YAML file. 3 | # All the keys under `diarizer` key (`vad`, `speaker_embeddings`, `clustering`, `msdd_model`, `asr`) can be selectively used for its own purpose and also can be ignored if the module is not used. 4 | # The configurations in this YAML file is suitable for 3~5 speakers participating in a meeting and may not show the best performance on other types of dialogues. 
5 | # An example line in an input manifest file (`.json` format): 6 | # {"audio_filepath": "/path/to/audio_file", "offset": 0, "duration": null, "label": "infer", "text": "-", "num_speakers": null, "rttm_filepath": "/path/to/rttm/file", "uem_filepath": "/path/to/uem/file"} 7 | name: &name "ClusterDiarizer" 8 | 9 | num_workers: 1 10 | sample_rate: 16000 11 | batch_size: 32 # use smaller batch size to avoid CUDA out of memory error 12 | device: null # can specify a specific device, i.e: cuda:1 (default cuda if cuda available, else cpu) 13 | verbose: True # enable additional logging 14 | 15 | diarizer: 16 | manifest_filepath: ??? 17 | out_dir: ??? 18 | oracle_vad: False # If True, uses RTTM files provided in the manifest file to get speech activity (VAD) timestamps 19 | collar: 0.25 # Collar value for scoring 20 | ignore_overlap: True # Consider or ignore overlap segments while scoring 21 | 22 | vad: 23 | model_path: vad_multilingual_marblenet # .nemo local model path or pretrained VAD model name 24 | external_vad_manifest: null # This option is provided to use external vad and provide its speech activity labels for speaker embeddings extraction. Only one of model_path or external_vad_manifest should be set 25 | 26 | parameters: # Tuned parameters for CH109 (using the 11 multi-speaker sessions as dev set) 27 | window_length_in_sec: 0.63 # Window length in sec for VAD context input 28 | shift_length_in_sec: 0.01 # Shift length in sec for generate frame level VAD prediction 29 | smoothing: False # False or type of smoothing method (eg: median) 30 | overlap: 0.5 # Overlap ratio for overlapped mean/median smoothing filter 31 | onset: 0.9 # Onset threshold for detecting the beginning and end of a speech 32 | offset: 0.5 # Offset threshold for detecting the end of a speech 33 | pad_onset: 0 # Adding durations before each speech segment 34 | pad_offset: 0 # Adding durations after each speech segment 35 | min_duration_on: 0 # Threshold for small non_speech deletion 36 | min_duration_off: 0.6 # Threshold for short speech segment deletion 37 | filter_speech_first: True 38 | 39 | speaker_embeddings: 40 | model_path: titanet_large # .nemo local model path or pretrained model name (titanet_large, ecapa_tdnn or speakerverification_speakernet) 41 | parameters: 42 | window_length_in_sec: [3.0,2.5,2.0,1.5,1.0,0.5] # Window length(s) in sec (floating-point number). either a number or a list. ex) 1.5 or [1.5,1.0,0.5] 43 | shift_length_in_sec: [1.5,1.25,1.0,0.75,0.5,0.25] # Shift length(s) in sec (floating-point number). either a number or a list. ex) 0.75 or [0.75,0.5,0.25] 44 | multiscale_weights: [1,1,1,1,1,1] # Weight for each scale. should be null (for single scale) or a list matched with window/shift scale count. ex) [0.33,0.33,0.33] 45 | save_embeddings: True # If True, save speaker embeddings in pickle format. This should be True if clustering result is used for other models, such as `msdd_model`. 46 | 47 | clustering: 48 | parameters: 49 | oracle_num_speakers: False # If True, use num of speakers value provided in manifest file. 50 | max_num_speakers: 8 # Max number of speakers for each recording. If an oracle number of speakers is passed, this value is ignored. 51 | enhanced_count_thres: 80 # If the number of segments is lower than this number, enhanced speaker counting is activated. 52 | max_rp_threshold: 0.25 # Determines the range of p-value search: 0 < p <= max_rp_threshold. 53 | sparse_search_volume: 30 # The higher the number, the more values will be examined with more time. 
54 | maj_vote_spk_count: False # If True, take a majority vote on multiple p-values to estimate the number of speakers. 55 | 56 | msdd_model: 57 | model_path: null # .nemo local model path or pretrained model name for multiscale diarization decoder (MSDD) 58 | parameters: 59 | use_speaker_model_from_ckpt: True # If True, use speaker embedding model in checkpoint. If False, the provided speaker embedding model in config will be used. 60 | infer_batch_size: 25 # Batch size for MSDD inference. 61 | sigmoid_threshold: [0.7] # Sigmoid threshold for generating binarized speaker labels. The smaller the more generous on detecting overlaps. 62 | seq_eval_mode: False # If True, use oracle number of speaker and evaluate F1 score for the given speaker sequences. Default is False. 63 | split_infer: True # If True, break the input audio clip to short sequences and calculate cluster average embeddings for inference. 64 | diar_window_length: 50 # The length of split short sequence when split_infer is True. 65 | overlap_infer_spk_limit: 5 # If the estimated number of speakers are larger than this number, overlap speech is not estimated. 66 | 67 | asr: 68 | model_path: stt_en_conformer_ctc_large # Provide NGC cloud ASR model name. stt_en_conformer_ctc_* models are recommended for diarization purposes. 69 | parameters: 70 | asr_based_vad: False # if True, speech segmentation for diarization is based on word-timestamps from ASR inference. 71 | asr_based_vad_threshold: 1.0 # Threshold (in sec) that caps the gap between two words when generating VAD timestamps using ASR based VAD. 72 | asr_batch_size: null # Batch size can be dependent on each ASR model. Default batch sizes are applied if set to null. 73 | decoder_delay_in_sec: null # Native decoder delay. null is recommended to use the default values for each ASR model. 74 | word_ts_anchor_offset: null # Offset to set a reference point from the start of the word. Recommended range of values is [-0.05 0.2]. 75 | word_ts_anchor_pos: "start" # Select which part of the word timestamp we want to use. The options are: 'start', 'end', 'mid'. 76 | fix_word_ts_with_VAD: False # Fix the word timestamp using VAD output. You must provide a VAD model to use this feature. 77 | colored_text: False # If True, use colored text to distinguish speakers in the output transcript. 78 | print_time: True # If True, the start and end time of each speaker turn is printed in the output transcript. 79 | break_lines: False # If True, the output transcript breaks the line to fix the line width (default is 90 chars) 80 | 81 | ctc_decoder_parameters: # Optional beam search decoder (pyctcdecode) 82 | pretrained_language_model: null # KenLM model file: .arpa model file or .bin binary file. 83 | beam_width: 32 84 | alpha: 0.5 85 | beta: 2.5 86 | 87 | realigning_lm_parameters: # Experimental feature 88 | arpa_language_model: null # Provide a KenLM language model in .arpa format. 89 | min_number_of_words: 3 # Min number of words for the left context. 90 | max_number_of_words: 10 # Max number of words for the right context. 91 | logprob_diff_threshold: 1.2 # The threshold for the difference between two log probability values from two hypotheses. 
92 | 93 | -------------------------------------------------------------------------------- /configs/inference/diarization/nemo/diar_infer_telephonic.yaml: -------------------------------------------------------------------------------- 1 | # This YAML file is created for all types of offline speaker diarization inference tasks in `/example/speaker_tasks/diarization` folder. 2 | # The inference parameters for VAD, speaker embedding extractor, clustering module, MSDD module, ASR decoder are all included in this YAML file. 3 | # All the keys under `diarizer` key (`vad`, `speaker_embeddings`, `clustering`, `msdd_model`, `asr`) can be selectively used for its own purpose and also can be ignored if the module is not used. 4 | # The configurations in this YAML file is suitable for telephone recordings involving 2~8 speakers in a session and may not show the best performance on the other types of acoustic conditions or dialogues. 5 | # An example line in an input manifest file (`.json` format): 6 | # {"audio_filepath": "/path/to/audio_file", "offset": 0, "duration": null, "label": "infer", "text": "-", "num_speakers": null, "rttm_filepath": "/path/to/rttm/file", "uem_filepath": "/path/to/uem/file"} 7 | name: &name "ClusterDiarizer" 8 | 9 | num_workers: 1 10 | sample_rate: 16000 11 | batch_size: 32 12 | device: null # can specify a specific device, i.e: cuda:1 (default cuda if cuda available, else cpu) 13 | verbose: True # enable additional logging 14 | 15 | diarizer: 16 | manifest_filepath: ??? 17 | out_dir: ??? 18 | oracle_vad: False # If True, uses RTTM files provided in the manifest file to get speech activity (VAD) timestamps 19 | collar: 0.25 # Collar value for scoring 20 | ignore_overlap: True # Consider or ignore overlap segments while scoring 21 | 22 | vad: 23 | model_path: vad_multilingual_marblenet # .nemo local model path or pretrained VAD model name 24 | external_vad_manifest: null # This option is provided to use external vad and provide its speech activity labels for speaker embeddings extraction. Only one of model_path or external_vad_manifest should be set 25 | 26 | parameters: # Tuned parameters for CH109 (using the 11 multi-speaker sessions as dev set) 27 | window_length_in_sec: 0.15 # Window length in sec for VAD context input 28 | shift_length_in_sec: 0.01 # Shift length in sec for generate frame level VAD prediction 29 | smoothing: "median" # False or type of smoothing method (eg: median) 30 | overlap: 0.5 # Overlap ratio for overlapped mean/median smoothing filter 31 | onset: 0.1 # Onset threshold for detecting the beginning and end of a speech 32 | offset: 0.1 # Offset threshold for detecting the end of a speech 33 | pad_onset: 0.1 # Adding durations before each speech segment 34 | pad_offset: 0 # Adding durations after each speech segment 35 | min_duration_on: 0 # Threshold for small non_speech deletion 36 | min_duration_off: 0.2 # Threshold for short speech segment deletion 37 | filter_speech_first: True 38 | 39 | speaker_embeddings: 40 | model_path: titanet_large # .nemo local model path or pretrained model name (titanet_large, ecapa_tdnn or speakerverification_speakernet) 41 | parameters: 42 | window_length_in_sec: [1.5,1.25,1.0,0.75,0.5] # Window length(s) in sec (floating-point number). either a number or a list. ex) 1.5 or [1.5,1.0,0.5] 43 | shift_length_in_sec: [0.75,0.625,0.5,0.375,0.25] # Shift length(s) in sec (floating-point number). either a number or a list. ex) 0.75 or [0.75,0.5,0.25] 44 | multiscale_weights: [1,1,1,1,1] # Weight for each scale. 
should be null (for single scale) or a list matched with window/shift scale count. ex) [0.33,0.33,0.33] 45 | save_embeddings: True # If True, save speaker embeddings in pickle format. This should be True if clustering result is used for other models, such as `msdd_model`. 46 | 47 | clustering: 48 | parameters: 49 | oracle_num_speakers: False # If True, use num of speakers value provided in manifest file. 50 | max_num_speakers: 8 # Max number of speakers for each recording. If an oracle number of speakers is passed, this value is ignored. 51 | enhanced_count_thres: 80 # If the number of segments is lower than this number, enhanced speaker counting is activated. 52 | max_rp_threshold: 0.25 # Determines the range of p-value search: 0 < p <= max_rp_threshold. 53 | sparse_search_volume: 30 # The higher the number, the more values will be examined with more time. 54 | maj_vote_spk_count: False # If True, take a majority vote on multiple p-values to estimate the number of speakers. 55 | 56 | msdd_model: 57 | model_path: diar_msdd_telephonic # .nemo local model path or pretrained model name for multiscale diarization decoder (MSDD) 58 | parameters: 59 | use_speaker_model_from_ckpt: True # If True, use speaker embedding model in checkpoint. If False, the provided speaker embedding model in config will be used. 60 | infer_batch_size: 25 # Batch size for MSDD inference. 61 | sigmoid_threshold: [0.7] # Sigmoid threshold for generating binarized speaker labels. The smaller the more generous on detecting overlaps. 62 | seq_eval_mode: False # If True, use oracle number of speaker and evaluate F1 score for the given speaker sequences. Default is False. 63 | split_infer: True # If True, break the input audio clip to short sequences and calculate cluster average embeddings for inference. 64 | diar_window_length: 50 # The length of split short sequence when split_infer is True. 65 | overlap_infer_spk_limit: 5 # If the estimated number of speakers are larger than this number, overlap speech is not estimated. 66 | 67 | asr: 68 | model_path: stt_en_conformer_ctc_large # Provide NGC cloud ASR model name. stt_en_conformer_ctc_* models are recommended for diarization purposes. 69 | parameters: 70 | asr_based_vad: False # if True, speech segmentation for diarization is based on word-timestamps from ASR inference. 71 | asr_based_vad_threshold: 1.0 # Threshold (in sec) that caps the gap between two words when generating VAD timestamps using ASR based VAD. 72 | asr_batch_size: null # Batch size can be dependent on each ASR model. Default batch sizes are applied if set to null. 73 | decoder_delay_in_sec: null # Native decoder delay. null is recommended to use the default values for each ASR model. 74 | word_ts_anchor_offset: null # Offset to set a reference point from the start of the word. Recommended range of values is [-0.05 0.2]. 75 | word_ts_anchor_pos: "start" # Select which part of the word timestamp we want to use. The options are: 'start', 'end', 'mid'. 76 | fix_word_ts_with_VAD: False # Fix the word timestamp using VAD output. You must provide a VAD model to use this feature. 77 | colored_text: False # If True, use colored text to distinguish speakers in the output transcript. 78 | print_time: True # If True, the start and end time of each speaker turn is printed in the output transcript. 
79 | break_lines: False # If True, the output transcript breaks the line to fix the line width (default is 90 chars) 80 | 81 | ctc_decoder_parameters: # Optional beam search decoder (pyctcdecode) 82 | pretrained_language_model: null # KenLM model file: .arpa model file or .bin binary file. 83 | beam_width: 32 84 | alpha: 0.5 85 | beta: 2.5 86 | 87 | realigning_lm_parameters: # Experimental feature 88 | arpa_language_model: null # Provide a KenLM language model in .arpa format. 89 | min_number_of_words: 3 # Min number of words for the left context. 90 | max_number_of_words: 10 # Max number of words for the right context. 91 | logprob_diff_threshold: 1.2 # The threshold for the difference between two log probability values from two hypotheses. 92 | 93 | -------------------------------------------------------------------------------- /configs/inference/inference_v1.yaml: -------------------------------------------------------------------------------- 1 | # * Key-value pairs that do not appear here will be set to default values defined in the dataclasses. 2 | # * Key names and value types will be verified. 3 | 4 | asr: 5 | model_name: 'large-v3' 6 | 7 | css: 8 | segment_size_sec: 3. 9 | hop_size_sec: 1.5 10 | device: "cuda:0" 11 | show_progressbar: True 12 | slice_audio_for_debug: False 13 | pass_through_ch0: False 14 | mc_mvdr: True 15 | mc_mask_floor_db: 0. # for MC, MVDR without any direct masking worked best 16 | sc_mask_floor_db: -inf # for SC, direct masking without floor worked best 17 | activity_th: 0.3 18 | 19 | diarization: 20 | method: 'word_nmesc' # choose from "word_nmesc", "nmesc" and "nmesc_msdd" 21 | min_embedding_windows: [3.0,2.5,2.0,1.5,1.0,0.5] 22 | embedding_model_name: "titanet_large" 23 | msdd_model_name: "diar_msdd_telephonic" 24 | # vad_model_name: "vad_telephony_marblenet" # for 8kHz telephone 25 | vad_model_name: "vad_multilingual_marblenet" # for 16kHz 26 | apply_deduplication: true 27 | 28 | scoring: 29 | save_visualizations: False 30 | 31 | ## one MC session: 32 | #session_query: 'device_name == "plaza_0" and is_mc == True and meeting_id == "MTG_30891"' 33 | 34 | # one SC session: 35 | #session_query: 'device_name == "plaza_0" and is_mc == False and meeting_id == "MTG_30891"' 36 | 37 | # all sessions: 38 | session_query: null 39 | -------------------------------------------------------------------------------- /configs/train_css/local/conformer_v0.51_mc.yaml: -------------------------------------------------------------------------------- 1 | # Note there's newer and better data. 
Do not download v1.2 2 | train_dir: ./v1.2/100hrs/train 3 | val_dir: ./v1.2/100hrs/val 4 | out_dir: ./ 5 | 6 | train_set_cfg: 7 | sample_frac: 1.0 8 | max_urls: null # null means no limit 9 | val_set_cfg: 10 | sample_frac: 1.0 11 | max_urls: null # null means no limit 12 | 13 | clip_gt_to_mixture: True 14 | 15 | log_params_mlflow: True 16 | log_metrics_mlflow: True 17 | 18 | scheduler_step_every: [1, iterations] 19 | scheduler_name: step_lr 20 | scheduler_linear_warmup_decay_cfg: 21 | warmup: 10000 22 | decay: 260000 23 | scheduler_step_lr_cfg: 24 | step_size: 1 25 | gamma: 1.0 # no decay 26 | 27 | stop_after: [260000, iterations] 28 | eval_every: [1000, iterations] 29 | save_every: [1000, iterations] 30 | 31 | global_batch_size: 256 32 | learning_rate: 1e-5 33 | weight_decay: 1e-2 # according to the paper set to 1e-2 34 | 35 | # Large model per CSS with Conformer definition 36 | conformer_css_cfg: 37 | nnet_conf: 38 | conformer_conf: 39 | attention_dim: 512 # default 256 40 | attention_heads: 8 # default 4 41 | num_blocks: 18 # default 16 42 | -------------------------------------------------------------------------------- /configs/train_css/local/conformer_v0.51_sc.yaml: -------------------------------------------------------------------------------- 1 | # Note there's newer and better data. Do not download v1.2 2 | train_dir: ./v1.2/100hrs/train 3 | val_dir: ./v1.2/100hrs/val 4 | out_dir: ./ 5 | 6 | single_channel: True 7 | 8 | train_set_cfg: 9 | sample_frac: 1.0 10 | max_urls: null # null means no limit 11 | val_set_cfg: 12 | sample_frac: 1.0 13 | max_urls: null # null means no limit 14 | 15 | clip_gt_to_mixture: True 16 | 17 | log_params_mlflow: True 18 | log_metrics_mlflow: True 19 | 20 | scheduler_step_every: [1, iterations] 21 | scheduler_name: step_lr 22 | scheduler_linear_warmup_decay_cfg: 23 | warmup: 10000 24 | decay: 260000 25 | scheduler_step_lr_cfg: 26 | step_size: 1 27 | gamma: 1.0 # no decay 28 | 29 | stop_after: [260000, iterations] 30 | eval_every: [1000, iterations] 31 | save_every: [1000, iterations] 32 | 33 | global_batch_size: 256 34 | learning_rate: 1e-5 35 | weight_decay: 1e-2 # according to the paper set to 1e-2 36 | 37 | # Large model per CSS with Conformer definition 38 | conformer_css_cfg: 39 | extractor_conf: 40 | ipd_index: '' # For MC '1,0;2,0;3,0;4,0;5,0;6,0'. For SC ''. 41 | nnet_conf: 42 | conformer_conf: 43 | attention_dim: 512 # default 256 44 | attention_heads: 8 # default 4 45 | num_blocks: 18 # default 16 46 | in_features: 257 # For MC 1799. For SC 257. 47 | -------------------------------------------------------------------------------- /configs/train_css/local/conformer_v0.5_mc.yaml: -------------------------------------------------------------------------------- 1 | # Note there's newer and better data. Do not download v1.2 2 | train_dir: ./v1.2/100hrs/train 3 | val_dir: ./v1.2/100hrs/val 4 | out_dir: ./ 5 | 6 | train_set_cfg: 7 | sample_frac: 1.0 8 | max_urls: null # null means no limit 9 | val_set_cfg: 10 | sample_frac: 1.0 11 | max_urls: null # null means no limit 12 | 13 | # This model was trained with clip_gt_to_mixture=False, but we recommend to set it to True. 
14 | clip_gt_to_mixture: False 15 | 16 | log_params_mlflow: True 17 | log_metrics_mlflow: True 18 | 19 | scheduler_step_every: [1, iterations] 20 | scheduler_name: step_lr 21 | scheduler_linear_warmup_decay_cfg: 22 | warmup: 10000 23 | decay: 260000 24 | scheduler_step_lr_cfg: 25 | step_size: 1 26 | gamma: 1.0 # no decay 27 | 28 | stop_after: [260000, iterations] 29 | eval_every: [1000, iterations] 30 | save_every: [1000, iterations] 31 | 32 | global_batch_size: 256 33 | learning_rate: 1e-5 34 | weight_decay: 1e-2 # according to the paper set to 1e-2 35 | 36 | # Large model per CSS with Conformer definition 37 | conformer_css_cfg: 38 | nnet_conf: 39 | conformer_conf: 40 | attention_dim: 512 # default 256 41 | attention_heads: 8 # default 4 42 | num_blocks: 18 # default 16 43 | -------------------------------------------------------------------------------- /configs/train_css/local/conformer_v0.5_sc.yaml: -------------------------------------------------------------------------------- 1 | # Note there's newer and better data. Do not download v1.2 2 | train_dir: ./v1.2/100hrs/train 3 | val_dir: ./v1.2/100hrs/val 4 | out_dir: ./ 5 | 6 | single_channel: True 7 | 8 | train_set_cfg: 9 | sample_frac: 1.0 10 | max_urls: null # null means no limit 11 | val_set_cfg: 12 | sample_frac: 1.0 13 | max_urls: null # null means no limit 14 | 15 | # This model was trained with clip_gt_to_mixture=False, but we recommend to set it to True. 16 | clip_gt_to_mixture: False 17 | 18 | log_params_mlflow: True 19 | log_metrics_mlflow: True 20 | 21 | scheduler_step_every: [1, iterations] 22 | scheduler_name: step_lr 23 | scheduler_linear_warmup_decay_cfg: 24 | warmup: 10000 25 | decay: 260000 26 | scheduler_step_lr_cfg: 27 | step_size: 1 28 | gamma: 1.0 # no decay 29 | 30 | stop_after: [260000, iterations] 31 | eval_every: [1000, iterations] 32 | save_every: [1000, iterations] 33 | 34 | global_batch_size: 256 35 | learning_rate: 1e-5 36 | weight_decay: 1e-2 # according to the paper set to 1e-2 37 | 38 | # Large model per CSS with Conformer definition 39 | conformer_css_cfg: 40 | extractor_conf: 41 | ipd_index: '' # For MC '1,0;2,0;3,0;4,0;5,0;6,0'. For SC ''. 42 | nnet_conf: 43 | conformer_conf: 44 | attention_dim: 512 # default 256 45 | attention_heads: 8 # default 4 46 | num_blocks: 18 # default 16 47 | in_features: 257 # For MC 1799. For SC 257. 48 | -------------------------------------------------------------------------------- /configs/train_css/local/conformer_v1.0_mc.yaml: -------------------------------------------------------------------------------- 1 | # Note that this model uses a mask-based loss, in contrast to a masked magnitude loss that was used in conformer_v0.5. 
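# Added clarification (an interpretation, not part of the original config): loss_name 'mask' with
# base_loss_name 'l1' further down presumably applies the L1 distance directly between predicted and
# reference time-frequency masks (with permutation-invariant matching, see css/training/losses.py),
# whereas the v0.5 configs compared mask-weighted mixture magnitudes against the target magnitudes.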
2 | train_dir: ./v1.5/200hrs/train 3 | val_dir: ./v1.5/200hrs/val 4 | out_dir: ./ 5 | 6 | train_set_cfg: 7 | sample_frac: 1.0 8 | max_urls: null # null means no limit 9 | val_set_cfg: 10 | sample_frac: 1.0 11 | max_urls: null # null means no limit 12 | 13 | calc_side_info: True 14 | log_params_mlflow: True 15 | log_metrics_mlflow: True 16 | 17 | scheduler_step_every: [1, iterations] 18 | scheduler_name: step_lr 19 | scheduler_step_lr_cfg: 20 | # Fixed LR 21 | step_size: 1 22 | gamma: 1.0 23 | 24 | stop_after: [520000, iterations] 25 | eval_every: [1000, iterations] 26 | save_every: [1000, iterations] 27 | 28 | loss_name: 'mask' 29 | base_loss_name: 'l1' 30 | 31 | global_batch_size: 256 32 | learning_rate: 1e-4 33 | weight_decay: 1e-2 # according to the paper set to 1e-2 34 | 35 | # Large model per CSS with Conformer definition 36 | conformer_css_cfg: 37 | nnet_conf: 38 | conformer_conf: 39 | attention_dim: 512 # default 256 40 | attention_heads: 8 # default 4 41 | num_blocks: 18 # default 16 42 | dropout_rate: 0.0 # New! The default was 0.1. 43 | -------------------------------------------------------------------------------- /configs/train_css/local/conformer_v1.0_sc.yaml: -------------------------------------------------------------------------------- 1 | # Note that this model uses a mask-based loss, in contrast to a masked magnitude loss that was used in conformer_v0.5. 2 | train_dir: ./v1.5/1000hrs/train 3 | val_dir: ./v1.5/200hrs/val # enough for validation 4 | out_dir: ./ 5 | 6 | single_channel: True 7 | 8 | train_set_cfg: 9 | sample_frac: 1.0 10 | max_urls: 640 # out of a total of 800. Subsample to fit local storage and avoid cache misses. 11 | val_set_cfg: 12 | sample_frac: 1.0 13 | max_urls: null # null means no limit 14 | 15 | calc_side_info: True 16 | log_params_mlflow: True 17 | log_metrics_mlflow: True 18 | 19 | scheduler_step_every: [1, iterations] 20 | scheduler_name: linear_warmup_decay 21 | scheduler_linear_warmup_decay_cfg: 22 | warmup: 10000 23 | decay: 520000 24 | 25 | stop_after: [520000, iterations] 26 | eval_every: [1000, iterations] 27 | save_every: [1000, iterations] 28 | 29 | loss_name: 'mask' 30 | base_loss_name: 'l1' 31 | 32 | global_batch_size: 256 33 | learning_rate: 1e-4 34 | weight_decay: 1e-2 # according to the paper set to 1e-2 35 | 36 | # Large model per CSS with Conformer definition 37 | conformer_css_cfg: 38 | extractor_conf: 39 | ipd_index: '' # For MC '1,0;2,0;3,0;4,0;5,0;6,0'. For SC ''. 40 | nnet_conf: 41 | conformer_conf: 42 | attention_dim: 512 # default 256 43 | attention_heads: 8 # default 4 44 | num_blocks: 18 # default 16 45 | dropout_rate: 0.0 # New! The default was 0.1. 46 | in_features: 257 # For MC 1799. For SC 257. 47 | -------------------------------------------------------------------------------- /configs/train_css/local/debug_mc.yaml: -------------------------------------------------------------------------------- 1 | # All paths are relative to the project root. 2 | train_dir: sample_data/css_train_set 3 | val_dir: sample_data/css_train_set # same as train_dir for debug purposes only! 
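# Added note: the two paths above point at the bundled sample data under sample_data/css_train_set,
# so this debug config can run end-to-end without downloading the full training set; the
# warmup/decay/stop_after settings below keep the run to 90 iterations.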
4 | out_dir: outputs/css_train 5 | 6 | train_set_cfg: 7 | sample_frac: 1.0 8 | max_urls: 2 # null means no limit 9 | val_set_cfg: 10 | sample_frac: 1.0 11 | max_urls: 2 # null means no limit 12 | 13 | clip_gt_to_mixture: True 14 | 15 | log_params_mlflow: False 16 | log_metrics_mlflow: False 17 | 18 | scheduler_step_every: [1, iterations] 19 | scheduler_name: linear_warmup_decay 20 | scheduler_linear_warmup_decay_cfg: 21 | warmup: 30 22 | decay: 60 23 | stop_after: [90, iterations] 24 | 25 | eval_every: [30, iterations] 26 | 27 | global_batch_size: 32 28 | learning_rate: 1e-4 29 | weight_decay: 1e-2 # according to the paper set to 1e-2 30 | 31 | is_debug: True 32 | -------------------------------------------------------------------------------- /configs/train_css/local/debug_sc.yaml: -------------------------------------------------------------------------------- 1 | # All paths are relative to the project root. 2 | train_dir: sample_data/css_train_set 3 | val_dir: sample_data/css_train_set # same as train_dir for debug purposes only! 4 | out_dir: outputs/css_train 5 | 6 | single_channel: True 7 | conformer_css_cfg: 8 | extractor_conf: 9 | ipd_index: '' # instead of 1,0;2,0;3,0;4,0;5,0;6,0 for MC 10 | nnet_conf: 11 | in_features: 257 # instead of 1799 for MC 12 | 13 | train_set_cfg: 14 | sample_frac: 1.0 15 | max_urls: 2 # null means no limit 16 | val_set_cfg: 17 | sample_frac: 1.0 18 | max_urls: 2 # null means no limit 19 | 20 | clip_gt_to_mixture: True 21 | 22 | log_params_mlflow: False 23 | log_metrics_mlflow: False 24 | 25 | scheduler_step_every: [1, iterations] 26 | scheduler_name: linear_warmup_decay 27 | scheduler_linear_warmup_decay_cfg: 28 | warmup: 30 29 | decay: 60 30 | stop_after: [90, iterations] 31 | 32 | eval_every: [30, iterations] 33 | 34 | global_batch_size: 32 35 | learning_rate: 1e-4 36 | weight_decay: 1e-2 # according to the paper set to 1e-2 37 | 38 | is_debug: true -------------------------------------------------------------------------------- /css/css_with_conformer/README.md: -------------------------------------------------------------------------------- 1 | The code under this directory is mostly a copy of "CSS with Conformer" from the original repo at the URL below. 2 | Some extentions were made when adopting to NOTSOFAR. 3 | 4 | We didn't copy the README.md file from the original repo because it contains Shared Access Signatures (SAS) that 5 | are considered secrets by some version control systems. One can find the original README.md file at the URL below. 6 | https://github.com/Sanyuan-Chen/CSS_with_Conformer/blob/master/README.md 7 | -------------------------------------------------------------------------------- /css/css_with_conformer/executor/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/NOTSOFAR1-Challenge/6f58e08b008f7530ba4141f0aeb02447c70b6fd7/css/css_with_conformer/executor/__init__.py -------------------------------------------------------------------------------- /css/css_with_conformer/executor/executor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import torch as th 4 | import torch.nn as nn 5 | from pathlib import Path 6 | from .feature import FeatureExtractor 7 | 8 | 9 | class Executor(nn.Module): 10 | """ 11 | Executor is a class to handle feature extraction 12 | and forward process of the separation networks. 
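    Added note: callers pass a dict of tensors, either {"mix": waveform} as in
    css/css_with_conformer/separate.py, or pre-computed {"mix": None, "mag": ..., "pha": ...}
    features as in css/training/conformer_wrapper.py.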
13 | """ 14 | def __init__(self, nnet, extractor_kwargs=None, get_mask=True): 15 | super(Executor, self).__init__() 16 | self.nnet = nnet 17 | self.extractor = FeatureExtractor( 18 | **extractor_kwargs) if extractor_kwargs else None 19 | self.frame_len = extractor_kwargs['frame_len'] if extractor_kwargs else None 20 | self.frame_hop = extractor_kwargs['frame_hop'] if extractor_kwargs else None 21 | self.get_mask = get_mask 22 | 23 | def resume(self, checkpoint): 24 | """ 25 | Resume from checkpoint 26 | """ 27 | if not Path(checkpoint).exists(): 28 | raise FileNotFoundError( 29 | f"Could not find resume checkpoint: {checkpoint}") 30 | cpt = th.load(checkpoint, map_location="cpu") 31 | self.load_state_dict(cpt["model_state_dict"]) 32 | return cpt["epoch"] 33 | 34 | def _compute_feats(self, egs): 35 | """ 36 | Compute features: N x F x T 37 | """ 38 | if not self.extractor: 39 | raise RuntimeError("self.extractor is None, " 40 | "do not need to compute features") 41 | mag, pha, f = self.extractor(**egs) 42 | return mag, pha, f 43 | 44 | def forward(self, egs): 45 | mag, pha, f = self._compute_feats(egs) 46 | out = self.nnet(f) 47 | if self.get_mask: 48 | return out 49 | else: 50 | return [self.extractor.istft(m * mag, pha) for m in out] 51 | -------------------------------------------------------------------------------- /css/css_with_conformer/nnet/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from . import conformer 3 | 4 | supported_nnet = { 5 | "conformer": conformer.ConformerCSS, 6 | } 7 | -------------------------------------------------------------------------------- /css/css_with_conformer/nnet/conformer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | """Implementation of Conformer speech separation model""" 5 | 6 | import math 7 | import numpy 8 | import torch 9 | from torch import nn 10 | 11 | 12 | class RelativePositionalEncoding(torch.nn.Module): 13 | def __init__(self, d_model, maxlen=1000, embed_v=False): 14 | super(RelativePositionalEncoding, self).__init__() 15 | 16 | self.d_model = d_model 17 | self.maxlen = maxlen 18 | self.pe_k = torch.nn.Embedding(2*maxlen, d_model) 19 | if embed_v: 20 | self.pe_v = torch.nn.Embedding(2*maxlen, d_model) 21 | self.embed_v = embed_v 22 | 23 | def forward(self, pos_seq): 24 | pos_seq.clamp_(-self.maxlen, self.maxlen - 1) 25 | pos_seq = pos_seq + self.maxlen 26 | if self.embed_v: 27 | return self.pe_k(pos_seq), self.pe_v(pos_seq) 28 | else: 29 | return self.pe_k(pos_seq), None 30 | 31 | 32 | class MultiHeadedAttention(nn.Module): 33 | """Multi-Head Attention layer. 
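    Added note: when relative positional embeddings (pos_k) are supplied to forward(), the
    attention scores below are computed as (q @ k^T + q @ pos_k^T) / sqrt(d_k) rather than
    plain scaled dot-product attention.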
34 | 35 | :param int n_head: the number of head s 36 | :param int n_feat: the number of features 37 | :param float dropout_rate: dropout rate 38 | 39 | """ 40 | 41 | def __init__(self, n_head, n_feat, dropout_rate): 42 | """Construct an MultiHeadedAttention object.""" 43 | super(MultiHeadedAttention, self).__init__() 44 | assert n_feat % n_head == 0 45 | # We assume d_v always equals d_k 46 | self.d_k = n_feat // n_head 47 | self.h = n_head 48 | self.layer_norm = nn.LayerNorm(n_feat) 49 | self.linear_q = nn.Linear(n_feat, n_feat) 50 | self.linear_k = nn.Linear(n_feat, n_feat) 51 | self.linear_v = nn.Linear(n_feat, n_feat) 52 | 53 | self.linear_out = nn.Linear(n_feat, n_feat) 54 | self.attn = None 55 | self.dropout = nn.Dropout(p=dropout_rate) 56 | 57 | def forward(self, x, pos_k, mask): 58 | """Compute 'Scaled Dot Product Attention'. 59 | 60 | :param torch.Tensor mask: (batch, time1, time2) 61 | :param torch.nn.Dropout dropout: 62 | :return torch.Tensor: attentined and transformed `value` (batch, time1, d_model) 63 | weighted by the query dot key attention (batch, head, time1, time2) 64 | """ 65 | n_batch = x.size(0) 66 | x = self.layer_norm(x) 67 | q = self.linear_q(x).view(n_batch, -1, self.h, self.d_k) #(b, t, d) 68 | k = self.linear_k(x).view(n_batch, -1, self.h, self.d_k) #(b, t, d) 69 | v = self.linear_v(x).view(n_batch, -1, self.h, self.d_k) 70 | q = q.transpose(1, 2) 71 | k = k.transpose(1, 2) # (batch, head, time2, d_k) 72 | v = v.transpose(1, 2) # (batch, head, time2, d_k) 73 | A = torch.matmul(q, k.transpose(-2, -1)) 74 | reshape_q = q.contiguous().view(n_batch * self.h, -1, self.d_k).transpose(0,1) 75 | if pos_k is not None: 76 | B = torch.matmul(reshape_q, pos_k.transpose(-2, -1)) 77 | B = B.transpose(0, 1).view(n_batch, self.h, pos_k.size(0), pos_k.size(1)) 78 | scores = (A + B) / math.sqrt(self.d_k) 79 | else: 80 | scores = A / math.sqrt(self.d_k) 81 | if mask is not None: 82 | mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2) 83 | min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min) 84 | scores = scores.masked_fill(mask, min_value) 85 | self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2) 86 | else: 87 | self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) 88 | 89 | p_attn = self.dropout(self.attn) 90 | x = torch.matmul(p_attn, v) # (batch, head, time1, d_k) 91 | x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model) 92 | return self.dropout(self.linear_out(x)) # (batch, time1, d_model) 93 | 94 | 95 | class ConvModule(nn.Module): 96 | def __init__(self, input_dim, kernel_size, dropout_rate, causal=False): 97 | super(ConvModule, self).__init__() 98 | self.layer_norm = nn.LayerNorm(input_dim) 99 | 100 | self.pw_conv_1 = nn.Conv2d(1, 2, 1, 1, 0) 101 | self.glu_act = torch.nn.Sigmoid() 102 | self.causal = causal 103 | if causal: 104 | self.dw_conv_1d = nn.Conv1d(input_dim, input_dim, kernel_size, 1, padding=(kernel_size-1), groups=input_dim) 105 | else: 106 | self.dw_conv_1d = nn.Conv1d(input_dim, input_dim, kernel_size, 1, padding=(kernel_size-1)//2, groups=input_dim) 107 | self.BN = nn.BatchNorm1d(input_dim) 108 | self.act = nn.ReLU() 109 | self.pw_conv_2 = nn.Conv2d(1, 1, 1, 1, 0) 110 | self.dropout = nn.Dropout(dropout_rate) 111 | self.kernel_size = kernel_size 112 | 113 | def forward(self, x): 114 | x = x.unsqueeze(1) 115 | x = self.layer_norm(x) 116 | x = self.pw_conv_1(x) 117 | x = x[:, 0] * self.glu_act(x[:, 1]) 118 | x = 
x.permute([0, 2, 1]) 119 | x = self.dw_conv_1d(x) 120 | if self.causal: 121 | x = x[:, :, :-(self.kernel_size-1)] 122 | x = self.BN(x) 123 | x = self.act(x) 124 | x = x.unsqueeze(1).permute([0, 1, 3, 2]) 125 | x = self.pw_conv_2(x) 126 | x = self.dropout(x).squeeze(1) 127 | return x 128 | 129 | 130 | class FeedForward(nn.Module): 131 | def __init__(self, d_model, d_inner, dropout_rate): 132 | super(FeedForward, self).__init__() 133 | 134 | self.d_model = d_model 135 | self.d_inner = d_inner 136 | 137 | self.layer_norm = nn.LayerNorm(d_model) 138 | self.net = nn.Sequential( 139 | nn.Linear(d_model, d_inner), 140 | nn.ReLU(inplace=True), 141 | nn.Dropout(dropout_rate), 142 | nn.Linear(d_inner, d_model), 143 | nn.Dropout(dropout_rate) 144 | ) 145 | 146 | def forward(self, x): 147 | x = self.layer_norm(x) 148 | out = self.net(x) 149 | 150 | return out 151 | 152 | 153 | class EncoderLayer(nn.Module): 154 | """Encoder layer module. 155 | 156 | :param int d_model: attention vector size 157 | :param int n_head: number of heads 158 | :param int d_ffn: feedforward size 159 | :param int kernel_size: cnn kernal size, it must be an odd 160 | :param int dropout_rate: dropout_rate 161 | """ 162 | 163 | def __init__(self, d_model, n_head, d_ffn, kernel_size, dropout_rate, causal=False): 164 | """Construct an EncoderLayer object.""" 165 | super(EncoderLayer, self).__init__() 166 | self.feed_forward_in = FeedForward(d_model, d_ffn, dropout_rate) 167 | self.self_attn = MultiHeadedAttention(n_head, d_model, dropout_rate) 168 | self.conv = ConvModule(d_model, kernel_size, dropout_rate, causal=causal) 169 | self.feed_forward_out = FeedForward(d_model, d_ffn, dropout_rate) 170 | self.layer_norm = nn.LayerNorm(d_model) 171 | 172 | def forward(self, x, pos_k, mask): 173 | """Compute encoded features. 
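        Added note: this follows the Conformer "macaron" structure implemented below:
        x + 0.5 * feed_forward_in(x), then self-attention, then the convolution module, then
        another 0.5 * feed_forward_out(x), with a final LayerNorm.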
174 | 175 | :param torch.Tensor x: encoded source features (batch, max_time_in, size) 176 | :param torch.Tensor mask: mask for x (batch, max_time_in) 177 | :rtype: Tuple[torch.Tensor, torch.Tensor] 178 | """ 179 | x = x + 0.5 * self.feed_forward_in(x) 180 | x = x + self.self_attn(x, pos_k, mask) 181 | x = x + self.conv(x) 182 | x = x + 0.5 * self.feed_forward_out(x) 183 | 184 | out = self.layer_norm(x) 185 | 186 | return out 187 | 188 | 189 | class ConformerEncoder(nn.Module): 190 | """Conformer Encoder https://arxiv.org/abs/2005.08100 191 | """ 192 | def __init__(self, 193 | idim=257, 194 | attention_dim=256, 195 | attention_heads=4, 196 | linear_units=1024, 197 | num_blocks=16, 198 | kernel_size=33, 199 | dropout_rate=0.1, 200 | causal=False, 201 | relative_pos_emb=True 202 | ): 203 | super(ConformerEncoder, self).__init__() 204 | 205 | self.embed = torch.nn.Sequential( 206 | torch.nn.Linear(idim, attention_dim), 207 | torch.nn.LayerNorm(attention_dim), 208 | torch.nn.Dropout(dropout_rate), 209 | torch.nn.ReLU(), 210 | ) 211 | 212 | if relative_pos_emb: 213 | self.pos_emb = RelativePositionalEncoding(attention_dim // attention_heads, 1000, False) 214 | else: 215 | self.pos_emb = None 216 | 217 | self.encoders = torch.nn.Sequential(*[EncoderLayer( 218 | attention_dim, 219 | attention_heads, 220 | linear_units, 221 | kernel_size, 222 | dropout_rate, 223 | causal=causal 224 | ) for _ in range(num_blocks)]) 225 | 226 | def forward(self, xs, masks): 227 | xs = self.embed(xs) 228 | 229 | if self.pos_emb is not None: 230 | x_len = xs.shape[1] 231 | pos_seq = torch.arange(0, x_len).long().to(xs.device) 232 | pos_seq = pos_seq[:, None] - pos_seq[None, :] 233 | pos_k, _ = self.pos_emb(pos_seq) 234 | else: 235 | pos_k = None 236 | for layer in self.encoders: 237 | xs = layer(xs, pos_k, masks) 238 | 239 | return xs, masks 240 | 241 | 242 | default_encoder_conf = { 243 | "attention_dim": 256, 244 | "attention_heads": 4, 245 | "linear_units": 1024, 246 | "num_blocks": 16, 247 | "kernel_size": 33, 248 | "dropout_rate": 0.1, 249 | "relative_pos_emb": True 250 | } 251 | 252 | 253 | class ConformerCSS(nn.Module): 254 | """ 255 | Conformer speech separation model 256 | """ 257 | def __init__(self, 258 | stats_file=None, 259 | in_features=257, 260 | num_bins=257, 261 | num_spks=2, 262 | num_nois=1, 263 | conformer_conf=default_encoder_conf): 264 | super(ConformerCSS, self).__init__() 265 | 266 | # input normalization layer 267 | if stats_file is not None: 268 | stats = numpy.load(stats_file) 269 | self.input_bias = torch.from_numpy(numpy.tile(numpy.expand_dims(-stats['mean'].astype(numpy.float32), axis=0), (1, 1, 1))) 270 | self.input_scale = torch.from_numpy(numpy.tile(numpy.expand_dims(1 / numpy.sqrt(stats['variance'].astype(numpy.float32)), axis=0), (1, 1, 1))) 271 | self.input_bias = nn.Parameter(self.input_bias, requires_grad=False) 272 | self.input_scale = nn.Parameter(self.input_scale, requires_grad=False) 273 | else: 274 | self.input_bias = torch.zeros(1,1,in_features) 275 | self.input_scale = torch.ones(1,1,in_features) 276 | self.input_bias = nn.Parameter(self.input_bias, requires_grad=False) 277 | self.input_scale = nn.Parameter(self.input_scale, requires_grad=False) 278 | 279 | # Conformer Encoders 280 | self.conformer = ConformerEncoder(in_features, **conformer_conf) 281 | 282 | self.num_bins = num_bins 283 | self.num_spks = num_spks 284 | self.num_nois = num_nois 285 | self.linear = nn.Linear(conformer_conf["attention_dim"], num_bins * (num_spks + num_nois)) 286 | 287 | def forward(self, f): 
288 | """ 289 | args 290 | f: N x * x T 291 | return 292 | m: [N x F x T, ...] 293 | """ 294 | # N x * x T => N x T x * 295 | f = f.transpose(1, 2) 296 | 297 | # global feature normalization 298 | f = f + self.input_bias 299 | f = f * self.input_scale 300 | 301 | f, _ = self.conformer(f, masks=None) 302 | m = self.linear(f) 303 | 304 | m = torch.sigmoid(m) 305 | 306 | # N x T x F => N x F x T 307 | m = m.transpose(1, 2) 308 | if self.num_spks > 1: 309 | m = torch.chunk(m, self.num_spks + self.num_nois, 1) 310 | return m 311 | -------------------------------------------------------------------------------- /css/css_with_conformer/separate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | 4 | This module and the contents of the "css_with_conformer" folder were adapted from 5 | https://github.com/Sanyuan-Chen/CSS_with_Conformer, with some modifications. 6 | 7 | """ 8 | 9 | import yaml 10 | import argparse 11 | from pathlib import Path 12 | 13 | import torch as th 14 | import numpy as np 15 | import soundfile as sf 16 | 17 | from css.css_with_conformer.nnet import supported_nnet 18 | from css.css_with_conformer.executor.executor import Executor 19 | from css.css_with_conformer.utils.audio_util import WaveReader 20 | from css.css_with_conformer.utils.mvdr_util import make_mvdr 21 | 22 | 23 | class EgsReader(object): 24 | """ 25 | Egs reader 26 | """ 27 | def __init__(self, 28 | mix_scp, 29 | sr=16000): 30 | self.mix_reader = WaveReader(mix_scp, sr=sr) 31 | 32 | def __len__(self): 33 | return len(self.mix_reader) 34 | 35 | def __iter__(self): 36 | for key, mix in self.mix_reader: 37 | egs = dict() 38 | egs["mix"] = mix 39 | yield key, egs 40 | 41 | 42 | class Separator(object): 43 | """ 44 | A simple wrapper for speech separation 45 | """ 46 | def __init__(self, cpt_dir, get_mask=False, device='cpu'): 47 | # load executor 48 | cpt_dir = Path(cpt_dir) 49 | self.get_mask = get_mask 50 | self.executor = self._load_executor(cpt_dir) 51 | cpt_ptr = cpt_dir / "best.pt.tar" 52 | epoch = self.executor.resume(cpt_ptr.as_posix()) 53 | print(f"Load checkpoint at {cpt_dir}, on epoch {epoch}") 54 | #print(f"Nnet summary: {self.executor}") 55 | self.device = device 56 | self.executor.to(self.device) 57 | self.executor.eval() 58 | 59 | def separate(self, egs): 60 | """ 61 | Do separation 62 | """ 63 | egs["mix"] = th.from_numpy(egs["mix"][None, :]).to(self.device, non_blocking=True) 64 | with th.no_grad(): 65 | spks = self.executor(egs) 66 | spks = [s.detach().squeeze().cpu().numpy() for s in spks] 67 | return spks 68 | 69 | def _load_executor(self, cpt_dir): 70 | """ 71 | Load executor from checkpoint 72 | """ 73 | with open(cpt_dir / "train.yaml", "r") as f: 74 | conf = yaml.load(f, Loader=yaml.FullLoader) 75 | nnet_type = conf["nnet_type"] 76 | if nnet_type not in supported_nnet: 77 | raise RuntimeError(f"Unknown network type: {nnet_type}") 78 | nnet = supported_nnet[nnet_type](**conf["nnet_conf"]) 79 | executor = Executor(nnet, extractor_kwargs=conf["extractor_conf"], get_mask=self.get_mask) 80 | return executor 81 | 82 | 83 | def run(args): 84 | wav, sr = sf.read(args.wav_file, dtype='float32') 85 | 86 | # separator 87 | seperator = Separator(args.checkpoint, device_id=args.device_id, get_mask=args.mvdr) 88 | 89 | dump_dir = Path(args.dump_dir) 90 | dump_dir.mkdir(exist_ok=True, parents=True) 91 | egs = {'mix': wav[int(sr * args.start): int(sr * args.end)]} 92 | duration_sec = egs['mix'].size / sr 93 | 94 | # print(f"Start Separation 
" + ("w/ mvdr" if args.mvdr else "w/o mvdr")) 95 | # for key, egs in egs_reader: 96 | # print(f"Processing utterance {key}...{egs}") 97 | mixed = egs["mix"] 98 | print('mixed',mixed.shape) 99 | spks = seperator.separate(egs) 100 | print('spks',len(spks),spks[0].shape) 101 | 102 | if args.mvdr: 103 | res1, res2 = make_mvdr(spks[:2], spks[2:], np.asfortranarray(mixed.T)) 104 | spks = [res1, res2] 105 | 106 | sf.write(dump_dir / f"{duration_sec}_mix.wav", egs['mix'][0].cpu().numpy(), sr) 107 | for i, s in enumerate(spks): 108 | if i < args.num_spks: 109 | write_path = dump_dir / f"{duration_sec}_{i}.wav" 110 | print(write_path) 111 | sf.write(write_path, s * 0.9 / np.max(np.abs(s)), sr) 112 | 113 | print(f"Done processing {args.wav_file}") 114 | 115 | 116 | def run_pretrained_sc_conformer(): 117 | args = argparse.Namespace( 118 | checkpoint=r"C:\Repos\NOTSOFAR\artifacts\css_models\CSS_with_Conformer\sc\1ch_conformer_base", 119 | wav_file=r"C:\Repos\NOTSOFAR\artifacts\ch0.wav", 120 | start=28, 121 | end=32, 122 | num_spks=2, 123 | device_id=0, 124 | sr=16000, 125 | dump_dir=r"C:\Repos\NOTSOFAR\artifacts\conformer_dump", 126 | mvdr=False 127 | ) 128 | 129 | run(args) 130 | 131 | 132 | if __name__ == "__main__": 133 | parser = argparse.ArgumentParser( 134 | description="Command to do speech separation", 135 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 136 | parser.add_argument("--checkpoint", type=str, help="Directory of checkpoint") 137 | parser.add_argument("--wav-file", 138 | type=str, 139 | required=True, 140 | help="mixed audio wav") 141 | parser.add_argument("--start", 142 | type=float, 143 | required=False, 144 | default=None, 145 | help="audio processing start (secs)") 146 | parser.add_argument("--end", 147 | type=float, 148 | required=False, 149 | default=None, 150 | help="audio processing end (secs)") 151 | parser.add_argument("--num_spks", 152 | type=int, 153 | default=2, 154 | help="Number of the speakers") 155 | parser.add_argument("--device-id", 156 | type=int, 157 | default=-1, 158 | help="GPU-id to offload model to, -1 means " 159 | "running on CPU") 160 | parser.add_argument("--sr", 161 | type=int, 162 | default=16000, 163 | help="Sample rate for mixture input") 164 | parser.add_argument("--dump-dir", 165 | type=str, 166 | default="sep", 167 | help="Directory to dump separated speakers") 168 | parser.add_argument("--mvdr", 169 | type=bool, 170 | default=False, 171 | help="apply mvdr") 172 | args = parser.parse_args() 173 | 174 | run(args) 175 | -------------------------------------------------------------------------------- /css/css_with_conformer/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/NOTSOFAR1-Challenge/6f58e08b008f7530ba4141f0aeb02447c70b6fd7/css/css_with_conformer/utils/__init__.py -------------------------------------------------------------------------------- /css/css_with_conformer/utils/audio_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import soundfile as sf 4 | import scipy.io.wavfile as wf 5 | 6 | MAX_INT16 = np.iinfo(np.int16).max 7 | EPSILON = np.finfo(np.float32).eps 8 | 9 | 10 | def _parse_script(scp_path, 11 | value_processor=lambda x: x, 12 | num_tokens=2, 13 | restrict=True): 14 | """ 15 | Parse kaldi's script(.scp) file 16 | If num_tokens >= 2, function will check token number 17 | """ 18 | scp_dict = dict() 19 | line = 0 20 | with open(scp_path, "r") as f: 21 
| for raw_line in f: 22 | scp_tokens = raw_line.strip().split() 23 | line += 1 24 | if (num_tokens >= 2 and len(scp_tokens) != num_tokens) or ( 25 | restrict and len(scp_tokens) < 2): 26 | raise RuntimeError( 27 | "For {}, format error in line[{:d}]: {}".format( 28 | scp_path, line, raw_line)) 29 | if num_tokens == 2: 30 | key, value = scp_tokens 31 | else: 32 | key, value = scp_tokens[0], scp_tokens[1:] 33 | if key in scp_dict: 34 | raise ValueError("Duplicated key \'{0}\' exists in {1}".format( 35 | key, scp_path)) 36 | scp_dict[key] = value_processor(value) 37 | return scp_dict 38 | 39 | 40 | class BaseReader(object): 41 | """ 42 | BaseReader Class 43 | """ 44 | def __init__(self, scp_rspecifier, **kwargs): 45 | self.index_dict = _parse_script(scp_rspecifier, **kwargs) 46 | self.index_keys = list(self.index_dict.keys()) 47 | 48 | def _load(self, key): 49 | # return path 50 | return self.index_dict[key] 51 | 52 | # number of utterance 53 | def __len__(self): 54 | return len(self.index_dict) 55 | 56 | # avoid key error 57 | def __contains__(self, key): 58 | return key in self.index_dict 59 | 60 | # sequential index 61 | def __iter__(self): 62 | for key in self.index_keys: 63 | yield key, self._load(key) 64 | 65 | 66 | class WaveReader(BaseReader): 67 | """ 68 | Sequential/Random Reader for single channel wave 69 | Format of wav.scp follows Kaldi's definition: 70 | key1 /path/to/wav 71 | ... 72 | """ 73 | def __init__(self, wav_scp, sr=16000, normalize=True): 74 | super(WaveReader, self).__init__(wav_scp) 75 | self.sr = sr 76 | self.normalize = normalize 77 | 78 | def _load(self, key): 79 | # return C x N or N 80 | sr, samps = read_wav(self.index_dict[key], 81 | normalize=self.normalize, 82 | return_rate=True) 83 | # if given samp_rate, check it 84 | if self.sr is not None and sr != self.sr: 85 | raise RuntimeError("Sample rate mismatch: {:d} vs {:d}".format( 86 | sr, self.sr)) 87 | 88 | return samps 89 | 90 | 91 | def read_wav(fname, beg=None, end=None, normalize=True, return_rate=False): 92 | """ 93 | Read wave files using scipy.io.wavfile(support multi-channel) 94 | """ 95 | # samps_int16: N x C or N 96 | # N: number of samples 97 | # C: number of channels 98 | if beg is not None: 99 | samps_int16, samp_rate = sf.read(fname, 100 | start=beg, 101 | stop=end, 102 | dtype="int16") 103 | else: 104 | samp_rate, samps_int16 = wf.read(fname) 105 | # N x C => C x N 106 | samps = samps_int16.astype(np.float32) 107 | # tranpose because I used to put channel axis first 108 | if samps.ndim != 1: 109 | samps = np.transpose(samps) 110 | # normalize like MATLAB and librosa 111 | if normalize: 112 | samps = samps / MAX_INT16 113 | if return_rate: 114 | return samp_rate, samps 115 | return samps 116 | 117 | 118 | def write_wav(fname, samps, sr=16000, normalize=True): 119 | """ 120 | Write wav files in int16, support single/multi-channel 121 | """ 122 | if normalize: 123 | samps = samps * MAX_INT16 124 | # scipy.io.wavfile.write could write single/multi-channel files 125 | # for multi-channel, accept ndarray [Nsamples, Nchannels] 126 | if samps.ndim != 1 and samps.shape[0] < samps.shape[1]: 127 | samps = np.transpose(samps) 128 | samps = np.squeeze(samps) 129 | # same as MATLAB and kaldi 130 | samps_int16 = samps.astype(np.int16) 131 | fdir = os.path.dirname(fname) 132 | if fdir: 133 | os.makedirs(fdir, exist_ok=True) 134 | # NOTE: librosa 0.6.0 seems could not write non-float narray 135 | # so use scipy.io.wavfile instead 136 | wf.write(fname, sr, samps_int16) 137 | 138 | 
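# Illustrative usage sketch (added; not part of the original file). It assumes a Kaldi-style
# "wav.scp" file whose lines look like "<key> /path/to/utterance.wav".
def _example_wav_scp_roundtrip(scp_path="wav.scp", out_dir="normalized"):
    reader = WaveReader(scp_path, sr=16000)  # yields (key, samples); samples are float32 in [-1, 1]
    for key, samps in reader:
        # samps is N for mono or C x N for multi-channel; write_wav rescales and transposes as needed
        write_wav(os.path.join(out_dir, f"{key}.wav"), samps, sr=16000)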
-------------------------------------------------------------------------------- /css/css_with_conformer/utils/mvdr_util.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import numpy as np 3 | 4 | 5 | def make_mvdr(spk_masks, noise_masks, mix_wav = None, mix_stft = None, return_stft=False): 6 | """ 7 | 8 | Args: 9 | mix_wav: mixture waveform, [Nsamples, Mics] tensor 10 | spk_masks: [num_spks, F, T] tensor 11 | noise_masks: [num_noise, F, T] tensor 12 | mix_stft: mixture STFT, [Mics, F, T] complex tensor 13 | return_stft: if True, return the STFT of the separated signals. 14 | Otherwise, return the separated signals in the time domain. 15 | 16 | Returns: 17 | 18 | """ 19 | all_masks = make_wta(spk_masks, noise_masks) # [num_spks + 1_noise, F, T] 20 | if mix_stft is None: 21 | mix_stft=[] 22 | for i in range(7): 23 | st=librosa.core.stft(mix_wav[:, i], n_fft=512, hop_length=256) 24 | mix_stft.append(st) 25 | mix_stft=np.asarray(mix_stft) # [Mics, F, T] 26 | 27 | L = np.min([all_masks.shape[-1],mix_stft.shape[-1]]) 28 | mix_stft = mix_stft[:,:,:L] 29 | all_masks = all_masks[:,:,:L] 30 | 31 | scms = [get_mask_scm(mix_stft, mask) for mask in all_masks] 32 | spk_scms = np.stack(scms[:-1]) # [num_spks, F, 7, 7] 33 | noise_scm = scms[-1] # [F, 7, 7] 34 | 35 | res_per_spk = [] 36 | for i in range(spk_scms.shape[0]): 37 | # sum SCMs of all other speakers 38 | other_spks_scm = spk_scms[np.arange(spk_scms.shape[0]) != i].sum(axis=0) 39 | # add noise and compute beamforming coefficients for the current speaker 40 | coef = calc_bfcoeffs(noise_scm + other_spks_scm, spk_scms[i]) 41 | res = get_bf(mix_stft, coef) 42 | res_per_spk.append(res) 43 | 44 | if not return_stft: 45 | res_per_spk = [librosa.istft(res, hop_length=256) for res in res_per_spk] 46 | 47 | return res_per_spk 48 | 49 | 50 | def make_wta(spk_masks, noise_masks): 51 | noise_mask = noise_masks.sum(axis=0, keepdims=True) 52 | mask = np.vstack([spk_masks, noise_mask]) 53 | mask_max = np.amax(mask, axis=0, keepdims=True) 54 | mask = np.where(mask==mask_max, mask, 1e-10) 55 | return mask 56 | 57 | 58 | def get_mask_scm(mix,mask): 59 | """Return spatial covariance matrix of the masked signal.""" 60 | 61 | Ri = np.einsum('FT,FTM,FTm->FMm', 62 | mask, mix.transpose(1,2,0), mix.transpose(1,2,0).conj()) 63 | t1=np.eye(7) 64 | t2=t1[np.newaxis,:,:] 65 | Ri+=1e-15*t2 66 | return Ri # ,np.sum(mask) 67 | 68 | 69 | def calc_bfcoeffs(noi_scm,tgt_scm): 70 | # Calculate BF coeffs. 
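    # Added clarification: per frequency bin this computes W = (R_noise^{-1} R_target) /
    # trace(R_noise^{-1} R_target) and keeps its first column, i.e. microphone 0 serves as the
    # reference channel of the beamformer.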
71 | num = np.linalg.solve(noi_scm, tgt_scm) 72 | den = np.trace(num, axis1=-2, axis2=-1)[..., np.newaxis, np.newaxis] 73 | den[0]+=1e-15 74 | W = (num / den)[..., 0] 75 | return W 76 | 77 | 78 | def get_bf(mix,W): 79 | c,f,t=mix.shape 80 | return np.sum(W.reshape(f,c,1).conj()*mix.transpose(1,0,2),axis=1) 81 | -------------------------------------------------------------------------------- /css/helpers.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | import soundfile 5 | import torch 6 | import torch.nn as nn 7 | from pathlib import Path 8 | 9 | from css.training.train import TrainCfg, get_model 10 | from utils.conf import load_yaml_to_dataclass 11 | from utils.mic_array_model import multichannel_mic_pos_xyz_cm 12 | 13 | 14 | def load_css_model(model_dir: Path) -> (nn.Module, TrainCfg): 15 | """Load multi-channel (mc) or single-channel (sc) CSS model from checkpoint and yaml files.""" 16 | 17 | def fetch_one_file(path: Path, suffix: str): 18 | files = list(path.glob(suffix)) 19 | if len(files) == 0: 20 | raise FileNotFoundError(f'expecting at least one {suffix} file in {path}') 21 | assert len(files) == 1, f'expecting exactly one {suffix} file in {path}' 22 | return str(files[0]) 23 | 24 | yaml_path = fetch_one_file(model_dir, '*.yaml') 25 | checkpoint_path = fetch_one_file(model_dir, '*.pt') 26 | 27 | train_cfg = load_yaml_to_dataclass(yaml_path, TrainCfg) 28 | separator = get_model(train_cfg) 29 | 30 | checkpoint = torch.load(checkpoint_path, map_location="cpu") 31 | 32 | def get_sub_state_dict(state_dict, prefix): 33 | # during training, model is wrapped in DP/DDP which introduces "module." prefix. remove it. 34 | return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)} 35 | 36 | separator.load_state_dict(get_sub_state_dict(checkpoint["model"], "module.")) 37 | return separator, train_cfg 38 | 39 | 40 | def load_audio(wav_file_names: List, is_mc: bool) -> (np.ndarray, int): 41 | """Loads audio data from wav files and returns it as a numpy array. 42 | Args: 43 | wav_file_names: list of wav file names. 44 | is_mc: True if multi-channel, False if single-channel. 45 | Returns: 46 | mix_wav: input audio data [Batch, n_samples, n_channels]. 47 | sr: sample rate. 48 | """ 49 | 50 | dtype = 'float32' 51 | if is_mc: 52 | num_mics = len(multichannel_mic_pos_xyz_cm()) 53 | assert len(wav_file_names) == num_mics, f'expecting {num_mics} microphones' 54 | # Read audio data and sampling rates from all files 55 | audio_data, srs = zip(*[soundfile.read(wav_file, dtype=dtype) for wav_file in wav_file_names]) 56 | mix_wav = np.stack(audio_data, axis=-1)[np.newaxis, ...] # -> [Batch, n_samples, n_channels] 57 | assert mix_wav.ndim == 3 and mix_wav.shape[2] in (1, 7) 58 | sr = srs[0] 59 | else: 60 | assert len(wav_file_names) == 1 61 | mix_wav, sr = soundfile.read(wav_file_names[0], dtype=dtype) 62 | assert mix_wav.ndim == 1 63 | mix_wav = mix_wav[np.newaxis, :, np.newaxis] # [Batch, n_samples, n_channels] 64 | 65 | return mix_wav, sr 66 | -------------------------------------------------------------------------------- /css/training/augmentations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Dict 3 | 4 | 5 | class MicShiftAugmentation: 6 | """ 7 | Augments data by randomly shifting circular microphones 1-6 cyclically while preserving mic 0. 8 | Assumption: mics are ordered 0..7 in features. 
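    (Seven microphones in total: mic 0 stays fixed and the six circular mics 1-6 are rotated.)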
9 | """ 10 | 11 | def __init__(self, seed: int, device: torch.device = torch.device('cpu')): 12 | self.rgen = torch.Generator(device=device) 13 | self.rgen.manual_seed(seed) 14 | 15 | def __call__(self, segment_batch: Dict) -> Dict: 16 | """Performs augmentation on a batch of segments. 17 | 18 | Args: 19 | segment_batch: a batch of segments, each segment is a dict of tensors. See the SimulatedDataset class for 20 | the expected keys. Note that the expected tensors are of shape [Batch, T, Mics] or 21 | [Batch, T, Mics, Spks], depending on the field. 22 | 23 | Returns: 24 | The batch with the same keys, but with the microphone arrays shifted. 25 | """ 26 | 27 | ignore_keys = ['utterance_id', 't0', 'seg_len', 'gt_spk_activity_scores',] 28 | # keys that require permutation 29 | mic_array_keys = ['mixture', 'gt_spk_direct_early_echoes', 'gt_spk_reverb', 'gt_noise'] 30 | 31 | not_covered = set(segment_batch) - set(ignore_keys + mic_array_keys) 32 | assert not not_covered, f'Unexpected keys! add them to ignore_keys, ' \ 33 | f'or to mic_array_keys and process them: {not_covered}' 34 | 35 | batch_size = segment_batch['mixture'].shape[0] 36 | shifts = torch.randint(0, 6, (batch_size,), generator=self.rgen, device=self.rgen.device) 37 | 38 | # shift all values by the same offset 39 | for key in mic_array_keys: 40 | if key in segment_batch: 41 | arr = segment_batch[key] 42 | assert arr.shape[2] == 7, 'expecting 7 microphones at dim 2' 43 | 44 | # Shift all mics except 0 45 | arr[:, :, 1:] = _batch_roll_dim2(arr[:, :, 1:], shifts) 46 | 47 | return segment_batch 48 | 49 | 50 | def _batch_roll_dim2(arr, shifts): 51 | """Rolls the values of the third dimension of a batch of tensors. 52 | 53 | Args: 54 | arr: The array of shape [Batch, T, Mics] or [Batch, T, Mics, Spks] to roll. 55 | shifts: The number of shifts to perform for each batch element. 56 | 57 | Returns: 58 | The rolled array. 59 | """ 60 | 61 | # Add a singleton dimension if needed 62 | orig_ndim = arr.ndim 63 | if orig_ndim == 3: 64 | arr = arr.unsqueeze(-1) 65 | 66 | # Assuming arr of shape [batch_size, mics, T, spks] 67 | batch_size, t, mics, spks = arr.shape 68 | 69 | # Create a grid of mic indices of the same shape as the input tensor 70 | indices = torch.arange(mics, device=arr.device)[None, None, :, None].repeat(batch_size, t, 1, spks) 71 | 72 | # Verify that shifts is a vector of the same size as the batch size 73 | assert shifts.shape == (batch_size,), f'Expecting shifts to be a vector of the same size as the batch size!' 74 | 75 | # Adjust indices for the shifts, ensuring wrapping around 76 | indices = (indices - shifts[:, None, None, None]) % mics 77 | 78 | # Gather the values from the input tensor according to the shifted indices. 79 | # The following will result in: 80 | # rolled[batch][t][mic][spk] = arr[batch][t][ indices[batch][t][mic] ][spk]. 
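    # Added worked example: with mics = 3 and a shift of 1 for some batch element, the index row
    # (0, 1, 2) becomes (2, 0, 1), so rolled[..., 0, :] = arr[..., 2, :], i.e. a cyclic roll of the
    # mic axis by one position for that element.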
81 | rolled = torch.gather(arr, 2, indices) 82 | 83 | # Remove the singleton dimension if needed 84 | if orig_ndim == 3: 85 | rolled = rolled.squeeze(-1) 86 | 87 | return rolled 88 | 89 | 90 | def test_batch_roll_dim2(): 91 | batch_size = 32 92 | t = 48000 93 | mics = 7 94 | spks = 3 95 | 96 | for with_spks in [True, False]: 97 | for i in range(100): 98 | m = torch.rand(batch_size, t, mics, spks) if with_spks else torch.rand(batch_size, t, mics) 99 | 100 | shifts = torch.randint(0, 6, (batch_size,)) 101 | 102 | # Fast version 103 | r1 = m.clone() 104 | r1[:, :, 1:] = _batch_roll_dim2(m[:, :, 1:], shifts) 105 | 106 | # Slow version 107 | r2 = m.clone() 108 | for b in range(batch_size): 109 | r2[b, :, 1:] = torch.roll(m[b, :, 1:], shifts=shifts[b].item(), dims=1) 110 | 111 | # Check that the results are the same 112 | assert (r1 == r2).all(), 'Failed!' 113 | 114 | print('batch_roll_dim2 test passed!') 115 | 116 | 117 | if __name__ == '__main__': 118 | test_batch_roll_dim2() 119 | -------------------------------------------------------------------------------- /css/training/conformer_wrapper.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field, asdict 2 | 3 | from css.css_with_conformer.executor.executor import Executor 4 | from css.css_with_conformer.nnet.conformer import ConformerCSS 5 | import torch as th 6 | import torch.nn as nn 7 | 8 | 9 | # The default values for the ExtractorCfg, ConformerCfg, and CssWithConformerCfg dataclasses were taken from the 10 | # conformer_base (MC) model. 11 | @dataclass 12 | class ExtractorCfg: 13 | ang_index: str = '' 14 | frame_hop: int = 256 15 | frame_len: int = 512 16 | ipd_cos: bool = False 17 | ipd_index: str = '1,0;2,0;3,0;4,0;5,0;6,0' 18 | ipd_mean_normalize: bool = True 19 | ipd_mean_normalize_version: int = 1 20 | log_spectrogram: bool = False 21 | mvn_spectrogram: bool = True 22 | num_spks: int = 2 23 | round_pow_of_two: bool = True 24 | window: str = 'hann' 25 | 26 | 27 | @dataclass 28 | class ConformerCfg: 29 | attention_dim: int = 256 30 | attention_heads: int = 4 31 | dropout_rate: float = 0.1 32 | kernel_size: int = 33 33 | linear_units: int = 1024 34 | num_blocks: int = 16 35 | 36 | 37 | @dataclass 38 | class NnetCfg: 39 | conformer_conf: ConformerCfg = field(default_factory=ConformerCfg) 40 | in_features: int = 1799 41 | num_nois: int = 1 # CSS_with_Conformer had 2 noise masks. NOTSOFAR has 1. 42 | num_spks: int = 3 # CSS_with_Conformer had 2 spk masks. NOTSOFAR has 3. 43 | 44 | 45 | @dataclass 46 | class ConformerCssCfg: 47 | extractor_conf: ExtractorCfg = field(default_factory=ExtractorCfg) 48 | nnet_conf: NnetCfg = field(default_factory=NnetCfg) 49 | 50 | 51 | class ConformerCssWrapper(nn.Module): 52 | """A thin wrapper around the Executor class, built to minimize code changes and to accept CssWithConformerCfg.""" 53 | def __init__(self, cfg: ConformerCssCfg): 54 | super().__init__() 55 | nnet = ConformerCSS(**asdict(cfg.nnet_conf)) 56 | self.executor = Executor(nnet, extractor_kwargs=asdict(cfg.extractor_conf), get_mask=True) 57 | 58 | def forward(self, mix: th.Tensor): 59 | """Compute the masks for the given time domain mixture. 60 | A simple composition of stft->separate methods. 61 | 62 | Args: 63 | mix (th.Tensor): The mixture of shape [Batch, T, Mics], in time domain, to compute the masks for. 64 | Returns: A dictionary with these keys: 65 | 'spk_masks' (th.Tensor): The masks for the speakers. Shape: [Batch, F, T, num_spks]. 
66 | 'noise_masks' (th.Tensor): The masks for the noise. Shape: [Batch, F, T, num_nois]. 67 | """ 68 | assert (mix.shape[2] == 1) == (self.executor.extractor.ipd_extractor is None), \ 69 | (f"IPD extractor is expected iff the number of microphones is greater than 1. " 70 | f"This may indicate model misconfiguration!") 71 | 72 | if mix.shape[2] == 1: 73 | mix = mix.squeeze(2) # [Batch, T, 1] -> [Batch, T] 74 | 75 | stft = self.stft(mix) 76 | res = self.separate(stft) 77 | return res 78 | 79 | def separate(self, stft: th.Tensor): 80 | """Compute separation masks for the given signal represented as stft (result of self.stft). 81 | 82 | Args: 83 | stft (th.Tensor): complex stft tensor of shape 84 | [Batch, F, T, Mics] (multi-channel) or [Batch, F, T] (single-channel) 85 | Returns: A dictionary with these keys: 86 | 'spk_masks' (th.Tensor): The masks for the speakers. Shape: [Batch, F, T, num_spks]. 87 | 'noise_masks' (th.Tensor): The masks for the noise. Shape: [Batch, F, T, num_nois]. 88 | """ 89 | assert th.is_complex(stft) 90 | if stft.dim() == 4: 91 | stft = stft.moveaxis(3, 1).contiguous() 92 | # [Batch, F, T, Mics] -> [Batch, Mics, F, T], and make contiguous. 93 | 94 | res = self.executor({"mix": None, 'mag': stft.abs(), 'pha': stft.angle()}) 95 | 96 | all_masks = th.cat([m.unsqueeze(-1) for m in res], dim=-1) 97 | 98 | assert all_masks.shape[-1] == self.executor.nnet.num_spks + self.executor.nnet.num_nois, \ 99 | f"Expected {self.executor.nnet.num_spks + self.executor.nnet.num_nois} masks, got {all_masks.shape[-1]}!" 100 | 101 | return { 102 | 'spk_masks': all_masks[..., :self.executor.nnet.num_spks], 103 | 'noise_masks': all_masks[..., self.executor.nnet.num_spks:] 104 | } 105 | 106 | def stft(self, s: th.Tensor): 107 | """Compute the STFT of a signal. 108 | 109 | Args: 110 | s (th.Tensor): The time domain signal to compute the STFT of. 111 | Shape: [Batch, T, Mics] for multi-channel. 112 | [Batch, T] for single-channel. 113 | 114 | Returns: 115 | A tensor of shape [Batch, F, T, Mics] and type th.complex64. 116 | """ 117 | 118 | if s.dim() == 3: 119 | s = s.moveaxis(1, 2).contiguous() # [Batch, T, Mics] -> [Batch, Mics, T], and make contiguous. 120 | 121 | mag, phase = self.executor.extractor.stft(s, cplx=False) # -> (mag, phase) tuple of [Batch, Mics, F, T] 122 | 123 | # to complex stft tensor 124 | stft_cplx = th.polar(mag, phase) # -> [Batch, Mics, F, T] 125 | 126 | if s.dim() == 3: 127 | stft_cplx = stft_cplx.moveaxis(1, 3).contiguous() # -> [Batch, F, T, Mics] 128 | 129 | return stft_cplx # -> [Batch, F, T, Mics], or [Batch, F, T] 130 | 131 | def istft(self, stft: th.Tensor): 132 | """Compute the inverse STFT of a signal. 133 | 134 | Args: 135 | stft (th.Tensor): The complex signal to compute the iSTFT of, with shape [Batch, F, T]. 136 | 137 | Returns: 138 | Time domain signal as [Batch, NSamples] tensor. 139 | """ 140 | assert th.is_complex(stft) 141 | assert stft.dim() == 3 142 | mag, phase = stft.abs(), stft.angle() 143 | 144 | res = self.executor.extractor.istft(mag, phase, cplx=False) 145 | 146 | return res 147 | 148 | 149 | # TODO: Remove before release. 150 | class DummyCss(nn.Module): 151 | """A dummy CSS model that does nothing.""" 152 | def __init__(self): 153 | super().__init__() 154 | 155 | l = nn.Linear(4096, 4096) 156 | layers = [] 157 | for i in range(5000): 158 | layers.append(l) 159 | layers.append(nn.ReLU()) 160 | 161 | self.seq = nn.Sequential(*layers) 162 | 163 | 164 | def forward(self, mix: th.Tensor): 165 | 166 | # Flatten the mix, except the batch dimension. 
167 | mix = mix.flatten(start_dim=1) 168 | # Take the first items to match the expected input size of the linear1 layer. 169 | mix = mix[:, :4096] 170 | # Pass the mix through the linear layers, with relu in between. 171 | mix = self.seq(mix) 172 | 173 | return { 174 | 'spk_masks': mix, 175 | 'noise_masks': mix+1 176 | } -------------------------------------------------------------------------------- /css/training/losses.py: -------------------------------------------------------------------------------- 1 | from itertools import permutations 2 | from typing import Callable 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from scipy.optimize import linear_sum_assignment 9 | 10 | 11 | class PitWrapper(nn.Module): 12 | """ 13 | Permutation Invariant Wrapper to allow Permutation Invariant Training 14 | (PIT) with existing losses. 15 | Permutation invariance is calculated over the sources axis which is 16 | assumed to be the rightmost dimension. 17 | Predictions and targets tensors are assumed to have shape [batch, ..., sources]. 18 | """ 19 | 20 | def __init__(self, base_loss: Callable): 21 | """ 22 | Args: 23 | base_loss (callable): 24 | Base loss function, e.g. torch.nn.MSELoss. It is assumed that it takes 25 | two arguments: 26 | predictions and targets and no reduction is performed. 27 | (if a pytorch loss is used, the user must specify reduction="none"). 28 | """ 29 | super(PitWrapper, self).__init__() 30 | self.base_loss = base_loss 31 | 32 | def _fast_pit(self, loss_mat): 33 | """ 34 | Args: 35 | loss_mat : Tensor of shape [sources, source] containing loss values for each 36 | prediction-target pair. 37 | 38 | Returns: 39 | loss : torch.Tensor, scalar, minimum loss over all permutations. 40 | target_perm (list) : Optimial permutation i.e. loss(predictions, targets[:, target_perm]) 41 | returns the minimum loss. 42 | """ 43 | left_inds, right_inds = linear_sum_assignment(loss_mat.data.cpu()) 44 | assert (left_inds == range(len(left_inds))).all() 45 | target_perm = right_inds 46 | loss = loss_mat[left_inds, right_inds].mean() 47 | 48 | return loss, target_perm 49 | 50 | def _opt_perm_loss(self, pred, target): 51 | n_sources = target.shape[-1] 52 | 53 | # [..., sources] -> [..., sources, sources] replicated along second to last dim 54 | ones = [1] * (len(target.shape) - 1) 55 | target = target.unsqueeze(-2).repeat(*ones, n_sources, 1) 56 | 57 | # [..., sources] -> [..., sources, sources] replicated along last dim 58 | ones = [1] * (len(pred.shape) - 1) 59 | pred = pred.unsqueeze(-1).repeat(1, *ones, n_sources) 60 | 61 | loss_mat = self.base_loss(pred, target) 62 | 63 | assert ( 64 | len(loss_mat.shape) >= 2 and loss_mat.shape[-2:] == target.shape[-2:] 65 | ), "Base loss should not perform any reduction operation" 66 | mean_over = tuple(range(loss_mat.dim() - 2)) # all but the last two dims 67 | if mean_over: 68 | loss_mat = loss_mat.mean(dim=mean_over) 69 | # loss_mat: [sources, sources] 70 | 71 | return self._fast_pit(loss_mat) 72 | 73 | def forward(self, preds, targets): 74 | """ 75 | Args: 76 | preds: predictions tensor, of shape [batch, ..., sources]. 77 | targets: targets tensor, of shape [batch, ..., sources]. 78 | 79 | Returns: 80 | ------- 81 | loss : Permutation invariant loss per instance, tensor of shape [batch]. 82 | perms (list) : 83 | List of indexes for optimal permutation of the targets per instance. 84 | Example: [(0, 1, 2), (2, 1, 0)] for three sources and batch size 2. 
85 | """ 86 | losses = [] 87 | perms = [] 88 | 89 | assert preds.shape[-1] == targets.shape[-1], \ 90 | "preds and targets expected to be padded to the same number of sources" 91 | 92 | for pred, label in zip(preds, targets): 93 | loss, p = self._opt_perm_loss(pred, label) 94 | perms.append(p) 95 | losses.append(loss) 96 | loss = torch.stack(losses) 97 | return loss, perms 98 | 99 | 100 | def mse_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 101 | """Computes MSE loss without any reduction.""" 102 | return F.mse_loss(pred, target, reduction="none") 103 | 104 | def l1_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 105 | """Computes L1 loss without any reduction.""" 106 | return F.l1_loss(pred, target, reduction="none") 107 | 108 | 109 | def test_pit_wrapper(): 110 | pit_mse = PitWrapper(mse_loss) 111 | 112 | torch.manual_seed(43236) 113 | 114 | for i in range(20): 115 | # (batch, time, freq, sources) 116 | targets = torch.rand((2, 100, 257, 4)) 117 | p = (3, 0, 2, 1) 118 | predictions = targets[..., p] 119 | loss, target_perm = pit_mse(predictions, targets) 120 | 121 | assert (loss == 0.).all() 122 | assert (predictions[0] == targets[0,..., target_perm[0]]).all() 123 | assert (np.stack(target_perm) == np.stack([p, p])).all() 124 | 125 | 126 | if __name__ == "__main__": 127 | test_pit_wrapper() -------------------------------------------------------------------------------- /css/training/schedulers.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from dataclasses import dataclass 3 | import torch.optim.lr_scheduler 4 | 5 | 6 | @dataclass 7 | class LinearWarmupDecayCfg: 8 | # Defaults are set according to the CSS with Conformer paper 9 | warmup: int = 10000 10 | decay: int = 260000 11 | 12 | 13 | class LinearWarmupDecayScheduler(torch.optim.lr_scheduler.LambdaLR): 14 | def __init__(self, optimizer, cfg: LinearWarmupDecayCfg, verbose=False): 15 | self.cfg = cfg 16 | super().__init__(optimizer, self._lr_lambda, verbose=verbose) 17 | 18 | def _lr_lambda(self, step): 19 | if step < self.cfg.warmup: 20 | res = step / self.cfg.warmup 21 | elif step < self.cfg.warmup + self.cfg.decay: 22 | res = 1 - (step - self.cfg.warmup) / self.cfg.decay 23 | else: 24 | if step > self.cfg.warmup + self.cfg.decay: 25 | warnings.warn(f'Learning rate has been decayed to zero! {step=}') 26 | res = 0 27 | 28 | if self.verbose: 29 | print(f'LinearWarmupDecayScheduler: {step=} {res=}') 30 | 31 | return res 32 | -------------------------------------------------------------------------------- /diarization/diarization.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | import pandas as pd 4 | from pathlib import Path 5 | 6 | from diarization.diarization_common import DiarizationCfg 7 | from diarization.time_based_diarization import time_based_diarization 8 | from diarization.word_based_diarization import word_based_clustering 9 | from utils.logging_def import get_logger 10 | from utils.torch_utils import get_world_size 11 | 12 | _LOG = get_logger('diarization') 13 | 14 | 15 | def diarization_inference(out_dir: str, segments_df: pd.DataFrame, cfg: DiarizationCfg, 16 | fetch_from_cache: bool, device: Optional[str] = None) -> pd.DataFrame: 17 | """ 18 | Run diarization to assign a speaker label to each ASR word. 19 | 20 | Two diarization modes are supported: 21 | 1. Pre-SR diarization that runs diarization without the knowledge of ASR. 
22 | In this mode, we directly call NeMo's diarization recipes, such as NMESC or NMESCC 23 | followed by MSDD. Then, for each ASR word, the speaker that is the most active within 24 | the word's time boundaries is assigned to the word. 25 | Set cfg.method to "nmesc" to use NMESC recipe of NeMo in the config file. 26 | Set cfg.method to "nmesc_msdd" to use the NMESC followed by MSDD recipe of NeMo. 27 | 2. Post-SR diarization that runs diarization after ASR. Allows the use of word boundaries. 28 | In this mode, we extract a speaker embedding vector for each word, and then call 29 | NeMo's NMESC for clustering. We also adopted the multi-scale speaker embedding window 30 | concept from NeMo, and extract multiple scale speaker embedding vectors for each word, 31 | each scale using different window sizes. The final affinity matrix is a simple average 32 | of the affinity matrixces of all the scales. 33 | To use this mode, set cfg.method to "word_nmesc". 34 | 35 | A known limitation of the diarization baseline is that the words from the CSS streams 36 | are pooled and clustered, and stream ID is not used in clustering. It is possible that 37 | words from different streams that overlap in time are assigned to the same speaker. 38 | This will trigger warning in tcp_wer and tcorc_wer computation and potentially degrade results. 39 | 40 | Args: 41 | out_dir: the directory to store generated files in the diarization step. 42 | This allows the cache of files and skip some steps when the code is run again. 43 | segments_df: a dataframe of transcribed segments for a given session, with columns: 44 | 'start_time': start time of the segment in seconds. 45 | 'end_time': end time of the segment in seconds. 46 | 'text': the text of the segment. 47 | 'word_timing': a list of [word, start, end] lists. 48 | 'meeting_id': the meeting id. 49 | 'session_id': the session id. 50 | 'wav_file_name': the name of the wav file that the segment was transcribed from. 51 | this is typically points to the speech separated wav file (see CSS module). 52 | cfg: diarization configuration. 53 | fetch_from_cache: If True, returns the cached results if they exist. Otherwise, runs the inference. 54 | device: the device to use for loading the model and running inference. 55 | Returns: 56 | attributed_segments_df: a new set of segments with 'speaker_id' column added. 
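    Example (illustrative call, with a hypothetical output directory):
        attributed_df = diarization_inference(out_dir='artifacts/outputs', segments_df=segments_df,
                                              cfg=DiarizationCfg(method='word_nmesc'),
                                              fetch_from_cache=False)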
57 | """ 58 | 59 | _LOG.info("Running Speaker Diarization") 60 | 61 | assert segments_df.session_id.nunique() <= 1, 'no cross-session information is permitted' 62 | 63 | # these two modes are for debugging and analysis 64 | if cfg.method == "skip": 65 | _LOG.info("Skipping Diarization") 66 | attributed_segments_df = segments_df.copy() 67 | attributed_segments_df['speaker_id'] = 'spk0' 68 | return attributed_segments_df 69 | elif cfg.method == "by_wav_file_name": 70 | attributed_segments_df = segments_df.copy() 71 | # map each unique wav_file_name to an index 72 | wav_file_name_ind, uniques = pd.factorize(attributed_segments_df['wav_file_name'], sort=True) 73 | attributed_segments_df['speaker_id'] = wav_file_name_ind 74 | attributed_segments_df['speaker_id'] = 'wav_' + attributed_segments_df['speaker_id'].astype(str) 75 | _LOG.info(f"Diarization by wav file names: {uniques}") 76 | return attributed_segments_df 77 | 78 | session_name = segments_df.session_id[0] 79 | is_ct = session_name.startswith('close_talk') 80 | assert segments_df.wav_file_name.nunique() <= 3 or is_ct, 'expecting at most three separated channels' 81 | output_dir = Path(out_dir) / "diarization" / session_name / cfg.method 82 | out_file = output_dir / "all_segments_df.pkl" 83 | 84 | # Skip cache and writing ops when running in DDP mode; evaluation must continue independently on each device 85 | skip_cache_and_write = get_world_size() > 1 86 | 87 | if not skip_cache_and_write: 88 | if fetch_from_cache and out_file.exists(): 89 | attributed_segments_df = pd.read_pickle(out_file) 90 | return attributed_segments_df 91 | os.makedirs(output_dir, exist_ok=True) 92 | 93 | segments_df = segments_df.copy() 94 | # wav_file_name as category to convert to indices 95 | segments_df['wav_file_name'] = segments_df['wav_file_name'].astype('category') 96 | assert 'wav_file_name_ind' not in segments_df 97 | segments_df['wav_file_name_ind'] = segments_df['wav_file_name'].cat.codes 98 | wav_files = segments_df['wav_file_name'].cat.categories.to_list() 99 | 100 | if cfg.method == "word_nmesc": 101 | attributed_segments_df = word_based_clustering(wav_files, segments_df, cfg, device) 102 | else: 103 | attributed_segments_df = time_based_diarization(wav_files, segments_df, str(output_dir), cfg) 104 | 105 | if not skip_cache_and_write: 106 | attributed_segments_df.to_pickle(out_file) 107 | _LOG.info(f'Speaker Diarization saved to {out_file}') 108 | 109 | return attributed_segments_df 110 | -------------------------------------------------------------------------------- /diarization/diarization_common.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | from dataclasses import dataclass, field 5 | 6 | 7 | # diarization inference configuration 8 | @dataclass 9 | class DiarizationCfg: 10 | method: str = "nmesc" # choose from "nmesc", "nmesc_msdd", or "word_nmesc"; "skip" and "by_wav_file_name" are debug modes 11 | min_embedding_windows: list = field(default_factory=list) 12 | max_allowed_word_duration: float = 3 # maximum allowed word duration. If a word is longer than this value, ignore it.
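    # Illustrative sketch (hypothetical values): a multi-scale word_nmesc setup might look like
    #   DiarizationCfg(method='word_nmesc', min_embedding_windows=[0.5, 1.0, 1.5])
    # where each window size is in seconds and yields one speaker embedding per word per scale
    # (see word_based_diarization.extract_speaker_embedding_for_words).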
13 | apply_deduplication: bool = True 14 | embedding_model_name: str = "titanet_large" 15 | msdd_model_name: str = "diar_msdd_telephonic" 16 | # vad_model_name: str = "vad_telephony_marblenet" # 8kHz telephone 17 | vad_model_name: str = "vad_multilingual_marblenet" # 16kHz 18 | 19 | 20 | def merge_words_to_segments_by_spk_change(all_words: list) -> dict: 21 | if len(all_words) == 0: 22 | return {"word_timing": [], "speaker_id": []} 23 | if len(all_words) == 1: 24 | return {"word_timing": [[w[:-1] for w in all_words]], "speaker_id": [all_words[0][-1]]} 25 | 26 | segments = {"word_timing": [], 27 | "speaker_id": []} 28 | seg_start = 0 29 | for i, word in enumerate(all_words): 30 | # if speaker ID is changed or channel ID is changed, break. This makes sure that each segment 31 | # contains words from a single channel, so the segments will be safe to compute tcorc_wer. 32 | if i > 0 and (word[-1] != all_words[seg_start][-1] or word[-2] != all_words[seg_start][-2]): 33 | seg_words = all_words[seg_start: i] 34 | segments["word_timing"].append([w[:-1] for w in seg_words]) 35 | segments["speaker_id"].append(seg_words[0][-1]) 36 | seg_start = i 37 | segments["word_timing"].append([w[:-1] for w in all_words[seg_start:]]) 38 | segments["speaker_id"].append(all_words[seg_start][-1]) 39 | 40 | return segments 41 | 42 | 43 | def compute_overlap_ratio(start1, end1, start2, end2): 44 | latest_start = max(start1, start2) 45 | earliest_end = min(end1, end2) 46 | overlap = earliest_end - latest_start 47 | 48 | if overlap < 0: 49 | return 0 # No overlap 50 | 51 | duration1 = end1 - start1 52 | duration2 = end2 - start2 53 | longer_duration = max(duration1, duration2) 54 | 55 | return overlap / longer_duration 56 | 57 | 58 | def deduplicate(all_words_sorted, overlap_threshold=0.5): 59 | all_words_deduplicated = [] 60 | for i, curr_word in enumerate(all_words_sorted): 61 | if i == 0: 62 | all_words_deduplicated.append(curr_word); continue # always keep the first word, there is nothing earlier to compare it with 63 | prev_word = all_words_sorted[i-1] 64 | skip_word = False 65 | if curr_word[0] == prev_word[0] and curr_word[4] == prev_word[4]: 66 | overlap_ratio = compute_overlap_ratio(curr_word[1], curr_word[2], prev_word[1], prev_word[2]) 67 | if overlap_ratio > overlap_threshold: 68 | # if identical words belong to the same speaker and appear 69 | # in multiple unmixed channels, and they overlap by more than 50%, 70 | # only keep the first one.
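                # note: each word entry at this point has the form [text, start, end, channel_id, speaker_id]
                # (built by the callers of prepare_diarized_data_frame), so index 0 is the word text and
                # index 4 is the speaker id compared above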
71 | skip_word = True 72 | if not skip_word: 73 | all_words_deduplicated.append(curr_word) 74 | 75 | return all_words_deduplicated 76 | 77 | 78 | def prepare_diarized_data_frame(all_words, segments_df, apply_deduplication): 79 | # cut word sequence into segments according to speaker change 80 | all_words_sorted = sorted(all_words, key=lambda x:x[2]) # sort words by end time 81 | if apply_deduplication: 82 | final_words = deduplicate(all_words_sorted) 83 | else: 84 | final_words = all_words_sorted 85 | segments = merge_words_to_segments_by_spk_change(final_words) 86 | 87 | diarized_segments_df = pd.DataFrame( 88 | {'start_time': [seg[0][1] for seg in segments["word_timing"]], 89 | 'end_time': [seg[-1][2] for seg in segments["word_timing"]], 90 | 'text': ["".join([w[0] for w in seg]) for seg in segments["word_timing"]], 91 | 'word_timing': segments["word_timing"]}) 92 | 93 | diarized_segments_df['meeting_id'] = segments_df['meeting_id'][0] 94 | diarized_segments_df['session_id'] = segments_df['session_id'][0] 95 | 96 | # assign correct CSS file name to each diarized segment 97 | stream_id = [seg[0][-1] for seg in diarized_segments_df.word_timing.to_list()] 98 | diarized_segments_df['wav_file_name'] = segments_df['wav_file_name'].cat.categories[stream_id] 99 | 100 | diarized_segments_df['speaker_id'] = segments["speaker_id"] 101 | 102 | return diarized_segments_df 103 | -------------------------------------------------------------------------------- /diarization/time_based_diarization.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import glob 4 | import shutil 5 | 6 | import pandas as pd 7 | import numpy as np 8 | from omegaconf import OmegaConf 9 | 10 | from nemo.collections.asr.models import ClusteringDiarizer 11 | from nemo.collections.asr.models.msdd_models import NeuralDiarizer 12 | 13 | from utils.audio_utils import read_wav, write_wav 14 | from diarization.diarization_common import prepare_diarized_data_frame, DiarizationCfg 15 | from utils.logging_def import get_logger 16 | 17 | _LOG = get_logger('time_based_diarization') 18 | 19 | 20 | def run_nemo_diarization(audio_files: list, session_output_dir: str, cfg: DiarizationCfg, vad_time_resolution: float=0.01): 21 | """ 22 | Run the diarization recipes from the NeMo toolkit. Two recipes can be used: NMESC and NMESC + MSDD. 23 | 24 | After diarization, represent the diarization results as a CxSxT tensor, where C is the number of unmixed 25 | channels, S is the number of diarized speakers, and T is the number of frames (defined by vad_time_resolution). 26 | """ 27 | os.makedirs(session_output_dir, exist_ok=True) 28 | num_audio_files = len(audio_files) 29 | if num_audio_files > 1: 30 | # if there is more than one unmixed channel, concatenate them and diarize. 31 | # Note that the time order information of the speech segments may not be properly used in the diarization.
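        # The channels are laid end-to-end in time (np.hstack below), so NeMo sees one long recording;
        # the per-channel results are recovered later by splitting the global speaker-VAD matrix back into
        # equal-length chunks (see the CxSxT conversion at the end of this function).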
32 | wavs = [read_wav(audio_file, normalize=True) for audio_file in audio_files] 33 | audio_file_to_diarize = os.path.join(session_output_dir, "concatenated.wav") 34 | write_wav(audio_file_to_diarize, np.hstack(wavs)) 35 | else: 36 | audio_file_to_diarize = audio_files[0] 37 | 38 | manifest = {"audio_filepath": audio_file_to_diarize, 39 | "offset": 0, 40 | "duration": None, 41 | "label": "infer", 42 | "text": "-", 43 | "num_speakers": None, 44 | "rttm_filepath": None, 45 | "uem_filepath": None, 46 | } 47 | manifest_file = os.path.join(session_output_dir, "manifest.json") 48 | json.dump(manifest, open(manifest_file, "w")) # don't use indent 49 | 50 | if cfg.method == "nmesc": 51 | # config_file = "configs/inference/diarization/nemo/diar_infer_general.yaml" 52 | # config_file = "configs/inference/diarization/nemo/diar_infer_telephonic.yaml" 53 | config_file = "configs/inference/diarization/nemo/diar_infer_meeting.yaml" 54 | nemo_conf = OmegaConf.load(config_file) 55 | nemo_conf.diarizer["manifest_filepath"] = manifest_file 56 | nemo_conf.diarizer["out_dir"] = session_output_dir 57 | nemo_conf.diarizer["vad"]["model_path"] = cfg.vad_model_name 58 | nemo_conf.diarizer["speaker_embeddings"]["model_path"] = cfg.embedding_model_name 59 | 60 | sd_model = ClusteringDiarizer(cfg=nemo_conf).to(nemo_conf.device) 61 | sd_model.diarize() 62 | 63 | elif cfg.method == "nmesc_msdd": 64 | # config_file = "configs/inference/diarization/nemo/diar_infer_general.yaml" 65 | config_file = "configs/inference/diarization/nemo/diar_infer_telephonic.yaml" # so far only this config works with MSDD 66 | # config_file = "configs/inference/diarization/nemo/diar_infer_meeting.yaml" 67 | nemo_conf = OmegaConf.load(config_file) 68 | nemo_conf.diarizer["manifest_filepath"] = manifest_file 69 | nemo_conf.diarizer["out_dir"] = session_output_dir 70 | nemo_conf.diarizer["vad"]["model_path"] = cfg.vad_model_name 71 | nemo_conf.diarizer["speaker_embeddings"]["model_path"] = cfg.embedding_model_name 72 | nemo_conf.diarizer["msdd_model"]["model_path"] = cfg.msdd_model_name 73 | 74 | diarizer_model = NeuralDiarizer(cfg=nemo_conf).to(nemo_conf.device) 75 | diarizer_model.diarize() 76 | 77 | else: 78 | raise ValueError(f"Unknown diarization method {cfg.method}!") 79 | 80 | # load diarization results from NeMo 81 | rttm_file = glob.glob(os.path.join(session_output_dir, "pred_rttms", "*.rttm")) 82 | if len(rttm_file) == 0: 83 | raise Exception("Diarization RTTM file was not created successfully!") 84 | elif len(rttm_file) > 1: 85 | raise Exception("More than one RTTM file found, expected only 1.") 86 | 87 | with open(rttm_file[0]) as file: 88 | sys_rttm = [line.rstrip('\n').split() for line in file] 89 | diarized_segments = [[float(seg[3]), float(seg[4]), seg[7]] for seg in sys_rttm] 90 | diarized_spk_uniq = sorted(list(set([seg[-1] for seg in diarized_segments]))) 91 | 92 | # represent diarization results as a global frame-based speaker VAD matrix. 93 | # The size of the speaker VAD matrix is SxT0, where S is the number of diarized speakers, 94 | # and T0 is the number of frames. In the case of multiple unmixed channels, T0 is the sum 95 | # of the durations (in frames) of the unmixed channels.
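    # Illustrative example: with vad_time_resolution=0.01, an RTTM segment starting at 12.34s with a
    # duration of 2.00s is marked active for its speaker over frames round(12.34/0.01)=1234 .. round(14.34/0.01)=1434,
    # matching the start/end frame computation below.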
96 | max_time = np.max(np.array([seg[:2] for seg in diarized_segments])) 97 | max_vad_frame = int(np.ceil(max_time / vad_time_resolution)) 98 | spk_vad = np.zeros((len(diarized_spk_uniq), max_vad_frame)) 99 | for seg in diarized_segments: 100 | start_frame = int(np.round(seg[0]/vad_time_resolution)) 101 | end_frame = int(np.round((seg[0]+seg[1])/vad_time_resolution)) 102 | spk_idx = diarized_spk_uniq.index(seg[2]) 103 | spk_vad[spk_idx, start_frame: end_frame] = 1 104 | 105 | # convert the global speaker VAD matrix to channel based speaker VAD matrices with size 106 | # CxSxT, where C is the number of unmixed channels, and T is from the duration of one unmixed channel. 107 | if num_audio_files > 1: 108 | # divide the global speaker VAD matrix into channel dependent speaker VAD matrices 109 | max_channel_vad_frame = int(max_vad_frame / num_audio_files) 110 | channel_spk_vad = np.zeros((num_audio_files, len(diarized_spk_uniq), max_channel_vad_frame)) 111 | for i in range(num_audio_files): 112 | tmp_vad = spk_vad[:, i*max_channel_vad_frame:(i+1)*max_channel_vad_frame] 113 | channel_spk_vad[i, :, :tmp_vad.shape[1]] = tmp_vad 114 | else: 115 | channel_spk_vad = spk_vad[np.newaxis] 116 | 117 | return channel_spk_vad 118 | 119 | 120 | def assign_words_to_speakers(segments_df: pd.DataFrame, spk_vad: np.array, apply_deduplication: bool, vad_time_resolution: float=0.01) -> pd.DataFrame: 121 | """ 122 | Given the diarization output and ASR word boundary information, assign an ASR word to the diarized speaker that is the 123 | most active during the word's time interval. 124 | """ 125 | has_unassigned_word = False 126 | all_words = [] 127 | for _, seg in segments_df.iterrows(): 128 | # get the unmixed channel id for current segment 129 | channel_id = seg.wav_file_name_ind 130 | 131 | for i, word in enumerate(seg["word_timing"]): 132 | start_frame = int(np.round(word[1]/vad_time_resolution)) 133 | end_frame = int(np.round(word[2]/vad_time_resolution)) 134 | end_frame = np.maximum(start_frame+1, end_frame) # make sure there is at least one frame for each word 135 | 136 | word_spk_count = spk_vad[channel_id][:, start_frame: end_frame] 137 | avg_word_spk_count = np.mean(word_spk_count, axis=1) 138 | if np.sum(avg_word_spk_count) == 0: # no valid speaker count from diarization 139 | all_words.append(word+[channel_id, None]) 140 | has_unassigned_word = True 141 | else: 142 | most_prob_spk_idx = np.argmax(avg_word_spk_count) 143 | all_words.append(word+[channel_id, f"spk{most_prob_spk_idx}"]) 144 | 145 | if has_unassigned_word: 146 | word_middle_times = [np.mean(word[1:3]) for word in all_words if word[-1] is not None] 147 | word_spk_ids = [word[-1] for word in all_words if word[-1] is not None] 148 | 149 | for word in all_words: 150 | if word[-1] is None: 151 | # if a word is not assigned a speaker label for some reason, use the speaker label of the nearest word 152 | word_middle_time = np.mean(word[1:3]) 153 | time_diff = np.abs(word_middle_times - word_middle_time) 154 | closest_word_idx = np.argmin(time_diff) 155 | word[-1] = word_spk_ids[closest_word_idx] 156 | _LOG.info(f"Word ({word[0]}, {word[1]:.2f}, {word[2]:.2f}) borrowed speaker ID ({word[-1]}) from word centered at {word_middle_times[closest_word_idx]:.2f}s. 
Time diff = {time_diff[closest_word_idx]:.2f}") 157 | 158 | diarized_segments_df = prepare_diarized_data_frame(all_words, segments_df, apply_deduplication) 159 | 160 | return diarized_segments_df 161 | 162 | 163 | def time_based_diarization(wav_files_sorted, segments_df, output_dir, cfg): 164 | """ 165 | Run NeMo diarization recipes. Combine the ASR word boundary information and diarization output 166 | to add a speaker label to each recognized word. 167 | """ 168 | # Step 1. Run NeMo diarization 169 | channel_spk_vad = run_nemo_diarization(wav_files_sorted, output_dir, cfg) 170 | 171 | # Step 2. Assign ASR words to diarized speakers 172 | attributed_segments_df = assign_words_to_speakers(segments_df, channel_spk_vad, cfg.apply_deduplication) 173 | 174 | return attributed_segments_df 175 | -------------------------------------------------------------------------------- /diarization/word_based_diarization.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | import pandas as pd 4 | import numpy as np 5 | import torch 6 | from torch.cuda.amp import autocast 7 | from tqdm import tqdm 8 | 9 | from nemo.collections.asr.models.label_models import EncDecSpeakerLabelModel 10 | from nemo.collections.asr.parts.utils.offline_clustering import NMESC, SpectralClustering, cos_similarity, getCosAffinityMatrix, getAffinityGraphMat 11 | 12 | from utils.audio_utils import read_wav 13 | from utils.torch_utils import is_dist_initialized 14 | from diarization.diarization import DiarizationCfg 15 | from diarization.diarization_common import prepare_diarized_data_frame, DiarizationCfg 16 | from utils.logging_def import get_logger 17 | from torch.nn.utils.rnn import pad_sequence 18 | _LOG = get_logger('word_based_diarization') 19 | 20 | 21 | def load_speaker_model(model_name: str, device: str): 22 | """ 23 | Load speaker embedding model defined in the NeMo toolkit. 24 | """ 25 | _LOG.info("Loading pretrained {} model from NGC".format(model_name)) 26 | spk_model = EncDecSpeakerLabelModel.from_pretrained(model_name=model_name, map_location=device) 27 | spk_model.eval() 28 | 29 | return spk_model 30 | 31 | 32 | def run_clustering(raw_affinity_mat: np.array, max_num_speakers: int=8, max_rp_threshold: float=0.06, sparse_search_volume: int=30): 33 | """ 34 | Run NMESC using the implementation from NeMo toolkit. 35 | """ 36 | nmesc = NMESC( 37 | raw_affinity_mat, 38 | max_num_speakers=max_num_speakers, 39 | max_rp_threshold=max_rp_threshold, 40 | sparse_search_volume=sparse_search_volume, 41 | ) 42 | 43 | est_num_of_spk, p_hat_value = nmesc.forward() 44 | affinity_mat = getAffinityGraphMat(raw_affinity_mat, p_hat_value) 45 | n_clusters = int(est_num_of_spk.item()) 46 | 47 | spectral_model = SpectralClustering(n_clusters=n_clusters) 48 | cluster_label = spectral_model.forward(affinity_mat) 49 | 50 | return cluster_label 51 | 52 | 53 | def batch_generator(data, batch_size): 54 | for i in range(0, len(data), batch_size): 55 | yield data[i:i + batch_size] 56 | 57 | 58 | def extract_speaker_embedding_for_words(segments_df, wavs, sr, spk_model, min_embedding_windows, max_allowed_word_duration=3, batch_size=32): 59 | """ 60 | For each word, use its word boundary information to extract multi-scale speaker embedding vectors.
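    Returns (all_words, all_word_embeddings): per-word entries of the form [text, start, end, channel_id],
    and for each word a stacked tensor holding one embedding per entry of min_embedding_windows.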
61 | """ 62 | wav_duration = wavs[0].size / sr 63 | 64 | all_words = [] 65 | all_word_embeddings = [] 66 | too_long_words = [] 67 | 68 | n_words = sum(len(seg['word_timing']) for _, seg in segments_df.iterrows()) 69 | _segments_df, _ = _fill_dummy_words_for_ddp(segments_df) 70 | words_processed = 0 71 | 72 | for _, seg in tqdm(_segments_df.iterrows(), desc='extracting speaker embedding for segments', total=len(_segments_df)): 73 | # get the unmixed channel id for current segment 74 | channel_id = seg.wav_file_name_ind 75 | 76 | for words_batch in batch_generator(seg["word_timing"], batch_size): 77 | word_embedding = [] 78 | word_wavs = [] 79 | word_lens = [] 80 | for word in words_batch: 81 | start_time = word[1] 82 | end_time = word[2] 83 | center_time = (start_time + end_time) / 2 84 | word_duration = end_time - start_time 85 | 86 | # extract multi-scale speaker embedding for the word 87 | for min_window_size in min_embedding_windows: 88 | if word_duration < min_window_size: 89 | # if the word duration is shorter than the window size, use a window centered at the word. 90 | # The window may cover other neighboring words 91 | start_time2 = np.maximum(0, center_time - min_window_size/2) 92 | end_time2 = np.minimum(wav_duration, center_time + min_window_size/2) 93 | start_sample = int(start_time2*sr) 94 | end_sample = int(end_time2*sr) 95 | else: 96 | start_sample = int(start_time*sr) 97 | end_sample = int(end_time*sr) 98 | 99 | word_wav = wavs[channel_id][start_sample:end_sample] 100 | word_wavs.append(torch.tensor(word_wav, dtype=torch.float32).to(spk_model.device)) 101 | word_lens.append(torch.tensor(word_wav.shape[0], dtype=torch.int).to(spk_model.device)) 102 | with autocast(), torch.no_grad(): 103 | word_wavs = pad_sequence(word_wavs, batch_first=True, padding_value=0) 104 | word_lens = torch.stack(word_lens) 105 | _, tmp_embedding = spk_model.forward(input_signal=word_wavs, input_signal_length=word_lens) 106 | word_embedding.extend(tmp_embedding.cpu().detach()) 107 | 108 | for word, word_embedding_local in zip(words_batch, batch_generator(word_embedding, len(min_embedding_windows))): 109 | words_processed += 1 110 | start_time = word[1] 111 | end_time = word[2] 112 | word_duration = end_time - start_time 113 | if words_processed > n_words: 114 | # This is a dummy word added for DDP. Skip it. 115 | continue 116 | 117 | if word_duration > max_allowed_word_duration: 118 | # A very long word duration is suspicious and may harm diarization. Ignore such words for now. 119 | # Note that these words will disappear in the final result. 120 | # TODO: find a better way to deal with these words. 121 | _LOG.info(f"word '{word[0]}' has unreasonably long duration ({start_time}s, {end_time}s). Skipping it in diarization") 122 | too_long_words.append(word) 123 | continue 124 | 125 | # append only the real words (do not append dummy words) 126 | all_words.append(word+[channel_id]) 127 | all_word_embeddings.append(torch.vstack(word_embedding_local)) 128 | 129 | print(f'Done extracting embeddings. {words_processed=}, {len(all_words)=}, {n_words=}', flush=True) 130 | n_real_words = n_words - len(too_long_words) 131 | assert len(all_words) == n_real_words, f"Number of words {len(all_words)} != n_real_words {n_real_words}" 132 | return all_words, all_word_embeddings 133 | 134 | 135 | def word_based_clustering(audio_files: list, segments_df: pd.DataFrame, cfg: DiarizationCfg, 136 | device: Optional[str] = None): 137 | """ 138 | Treat each ASR word as a segment and run NMESC for clustering.
139 | 140 | Here, we implicitly use ASR as the VAD, and only consider the speech regions that are recognized into 141 | words. For each word, we create a speech segment using the word's time boundaries (start/end times). 142 | These word-based speech segments are used as the inputs to clustering. 143 | 144 | As a word's duration is usually too short for extracting reliable speaker embeddings, this function 145 | uses longer windows centered at the word to extract speaker embeddings. 146 | 147 | Motivated by the multi-scale affinity matrices proposed in NeMo's diarization recipe, this function 148 | also supports multi-scale speaker embedding extraction. Set multiple window sizes in cfg.min_embedding_windows. 149 | The affinity matrices of different scales are all weighted equally. 150 | 151 | Note that in NeMo's recipe, a larger-scale affinity matrix contains fewer elements and resampling is needed 152 | to make the affinity matrices of all scales the same size. In this function, all affinity matrices 153 | have the same size, i.e. NxN, where N is the number of words. So no resampling is needed. 154 | """ 155 | # load unmixed waveforms 156 | srs, wavs = zip(*[read_wav(audio_file, normalize=True, return_rate=True) for audio_file in audio_files]) 157 | sr = srs[0] 158 | max_length = max([wav.size for wav in wavs]) 159 | # pad to the maximum length and stack. padding is only relevant to segmented close-talk. 160 | # CSS always returns equal-length channels. 161 | wavs = np.vstack( 162 | [np.pad(wav, (0, max_length - wav.size), 'constant', constant_values=(0, 0)) for wav in 163 | wavs]) 164 | 165 | # load speaker embedding model 166 | spk_model = load_speaker_model(cfg.embedding_model_name, device=device) 167 | 168 | # extract word-based multi-scale speaker embedding vectors 169 | all_words, all_word_embeddings = extract_speaker_embedding_for_words(segments_df, wavs, sr, spk_model, 170 | cfg.min_embedding_windows, 171 | cfg.max_allowed_word_duration) 172 | 173 | # compute affinity matrix for clustering 174 | all_word_embeddings2 = torch.stack(all_word_embeddings) 175 | emb_t = all_word_embeddings2.half().to(spk_model.device) 176 | # compute affinity matrix for each scale 177 | scale_affinity = [getCosAffinityMatrix(emb_t[:, scale]) 178 | for scale in range(len(cfg.min_embedding_windows))] 179 | # final affinity matrix is the average of scale-dependent affinity matrices 180 | affinity = torch.mean(torch.stack(scale_affinity), dim=0) 181 | 182 | # run NMESC 183 | cluster_label = run_clustering(affinity) 184 | 185 | # prepare segment data frame 186 | all_words = [word+[f"spk{spk_idx}"] for word, spk_idx in zip(all_words, cluster_label)] 187 | diarized_segments_df = prepare_diarized_data_frame(all_words, segments_df, cfg.apply_deduplication) 188 | 189 | return diarized_segments_df 190 | 191 | 192 | def _fill_dummy_words_for_ddp(segments_df: pd.DataFrame) -> tuple[pd.DataFrame, int]: 193 | """ 194 | Fill the last segment with dummy words to make the number of words the same across all processes in DDP.
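    In DDP, every rank must run the same number of speaker-model forward passes to stay in sync, so the
    word count is padded up to the largest count across ranks; the dummy words are skipped again later in
    extract_speaker_embedding_for_words.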
195 | 196 | Returns: 197 | (a COPY of segments_df with dummy words added to the last segment, number of dummy words added) 198 | """ 199 | 200 | if not is_dist_initialized(): 201 | return segments_df, 0 202 | 203 | n_words = sum(len(seg['word_timing']) for _, seg in segments_df.iterrows()) 204 | max_words = get_max_value(n_words) 205 | print(f"Number of segments: {len(segments_df)}, Number of words: {n_words}, max_words(in DDP): {max_words}") 206 | 207 | # find first segment with non-empty word_timing 208 | for i in range(len(segments_df)): 209 | if len(segments_df.iloc[i]['word_timing']) > 0: 210 | dummy_word = segments_df.iloc[i]['word_timing'][-1].copy() 211 | break 212 | 213 | # fill last segment with dummy data 214 | _segments_df = segments_df.copy() 215 | n_dummies = max_words - n_words 216 | for _ in range(n_dummies): 217 | _segments_df.iloc[-1]['word_timing'].append(dummy_word) 218 | 219 | n_words_with_dummies = sum([len(seg['word_timing']) for _, seg in _segments_df.iterrows()]) 220 | assert n_words_with_dummies == max_words, \ 221 | f"Number of words with dummies {n_words_with_dummies} != max_words {max_words}" 222 | print(f"Number of words to process (with dummies): {n_words_with_dummies}") 223 | 224 | return _segments_df, n_dummies 225 | -------------------------------------------------------------------------------- /inference_pipeline/inference.py: -------------------------------------------------------------------------------- 1 | from dataclasses import field, dataclass 2 | from functools import partial 3 | from pathlib import Path 4 | from typing import Optional 5 | 6 | import tqdm 7 | import pandas as pd 8 | 9 | from asr.asr import asr_inference, WhisperAsrCfg 10 | from css.css import css_inference, CssCfg 11 | from diarization.diarization import diarization_inference 12 | from diarization.diarization_common import DiarizationCfg 13 | from inference_pipeline.load_meeting_data import load_data 14 | from utils.logging_def import get_logger 15 | from utils.scoring import ScoringCfg, calc_wer, df_to_seglst, normalize_segment, write_submission_jsons 16 | 17 | _LOG = get_logger('inference') 18 | 19 | 20 | @dataclass 21 | class InferenceCfg: 22 | css: CssCfg = field(default_factory=CssCfg) 23 | asr: WhisperAsrCfg = field(default_factory=WhisperAsrCfg) 24 | diarization: DiarizationCfg = field(default_factory=DiarizationCfg) 25 | scoring: ScoringCfg = field(default_factory=ScoringCfg) 26 | # Optional: Query to filter all_session_df. Useful for debugging. Must be None during full evaluation. 27 | session_query: Optional[str] = None 28 | 29 | 30 | @dataclass 31 | class FetchFromCacheCfg: 32 | css: bool = False 33 | asr: bool = False 34 | diarization: bool = False 35 | 36 | 37 | def inference_pipeline(meetings_dir: str, models_dir: str, out_dir: str, cfg: InferenceCfg, 38 | cache: FetchFromCacheCfg): 39 | """ 40 | Run the inference pipeline on sessions loaded from meetings_dir. 41 | 42 | Args: 43 | meetings_dir: directory with meeting data. 44 | example: project_root/artifacts/meeting_data/dev_set/240121_dev/MTG/ 45 | models_dir: directory with CSS models. 46 | example: project_root/artifacts/css_models/ 47 | out_dir: modules will write their outputs here. 48 | cfg: config per module. 49 | cache: basic cache mechanism to re-use results per module. Off by default. 50 | Note: use at your own risk. If you modify code or config, make sure to delete the cache 51 | or set to False.
52 | """ 53 | # Load all meetings from the meetings dir 54 | _LOG.info(f'loading meetings from: {meetings_dir}') 55 | all_session_df, all_gt_utt_df, all_gt_metadata_df = load_data(meetings_dir, cfg.session_query) 56 | 57 | wer_dfs, hyp_jsons = [], [] 58 | # Process each session independently. (Cross-session information is not permitted) 59 | for _, session in tqdm.tqdm(all_session_df.iterrows(), desc='processing sessions'): 60 | _LOG.info(f'Processing session: {session.session_id}') 61 | 62 | # Front-end: split session into enhanced streams without overlap speech 63 | session: pd.Series = css_inference(out_dir, models_dir, session, cfg.css, cache.css) 64 | 65 | # Run ASR on each stream and return transcribed segments 66 | segments_df: pd.DataFrame = asr_inference(out_dir, session, cfg.asr, cache.asr) 67 | 68 | # Return speaker attributed segments (re-segmentation can occur) 69 | attributed_segments_df: pd.DataFrame = ( 70 | diarization_inference(out_dir, segments_df, cfg.diarization, cache.diarization)) 71 | 72 | # Write hypothesis transcription to: outdir / wer / {multi|single}channel / session_id / *.json 73 | # These will be merged into one json per track (mc/sc) for submission below. 74 | hyp_paths: pd.Series = write_hypothesis_jsons( 75 | out_dir, session, attributed_segments_df, cfg.asr.text_normalizer()) 76 | hyp_jsons.append(hyp_paths) 77 | 78 | # Calculate session WER if GT is available 79 | if all_gt_utt_df is not None: 80 | # Rules: WER metric, arguments (collar), and text normalizer must remain unchanged 81 | calc_wer_out = Path(out_dir) / 'wer' / session.session_id 82 | session_wer: pd.DataFrame = calc_wer( 83 | calc_wer_out, 84 | hyp_paths.tcp_wer_hyp_json, 85 | hyp_paths.tcorc_wer_hyp_json, 86 | all_gt_utt_df, 87 | cfg.asr.text_normalizer(), 88 | collar=5, save_visualizations=cfg.scoring.save_visualizations) 89 | wer_dfs.append(session_wer) 90 | 91 | # To submit results to one of the tracks, upload the tcp_wer_hyp.json and tc_orc_wer_hyp.json located in: 92 | # outdir/wer/{singlechannel | multichannel}/ 93 | hyp_jsons_df = pd.DataFrame(hyp_jsons) 94 | write_submission_jsons(out_dir, hyp_jsons_df) 95 | 96 | if wer_dfs: # GT available 97 | all_session_wer_df = pd.concat(wer_dfs, ignore_index=True) 98 | _LOG.info(f'Results:\n{all_session_wer_df}') 99 | _LOG.info(f'mean tcp_wer = {all_session_wer_df["tcp_wer"].mean()}') 100 | _LOG.info(f'mean tcorc_wer = {all_session_wer_df["tcorc_wer"].mean()}') 101 | 102 | # write session level results into a file 103 | exp_id = "_".join(['css', cfg.asr.model_name, cfg.diarization.method]) 104 | result_file = Path(out_dir) / "wer" / f"{exp_id}_results.csv" 105 | result_file.parent.mkdir(parents=True, exist_ok=True) 106 | all_session_wer_df.to_csv(result_file, sep="\t") 107 | _LOG.info(f"Wrote full results to: {result_file}") 108 | # TODO confidence intervals, WER per meta-data 109 | 110 | 111 | def write_hypothesis_jsons(out_dir, session: pd.Series, 112 | attributed_segments_df: pd.DataFrame, 113 | text_normalizer): 114 | """ 115 | Write hypothesis transcripts for session, to be used for tcpwer and tcorwer metrics. 
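    Returns a pd.Series with the session_id and the paths of the written tcp_wer_hyp.json and
    tc_orc_wer_hyp.json files, consumed by calc_wer and write_submission_jsons in inference_pipeline above.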
116 | """ 117 | 118 | _LOG.info(f'Writing hypothesis transcripts for session {session.session_id}') 119 | 120 | def write_json(df, filename): 121 | filepath = Path(out_dir) / 'wer' / session.session_id / filename 122 | filepath.parent.mkdir(parents=True, exist_ok=True) 123 | seglst = df_to_seglst(df) 124 | seglst = seglst.map(partial(normalize_segment, tn=text_normalizer)) 125 | seglst.dump(filepath) 126 | _LOG.info(f'Wrote {filepath}') 127 | return filepath 128 | 129 | # I. hyp file for tcpWER 130 | tcp_wer_hyp_json = write_json(attributed_segments_df, 'tcp_wer_hyp.json') 131 | 132 | # II. hyp file for tcORC-WER, a supplementary metric for analysis. 133 | # meeteval.wer.tcorcwer requires a stream ID, which depends on the system. 134 | # Overlapped words should go into different streams, or appear in one stream while respecting the order 135 | # in reference. See https://github.com/fgnt/meeteval. 136 | # In NOTSOFAR we define the streams as the outputs of CSS (continuous speech separation). 137 | # If your system does not have CSS you need to define the streams differently. 138 | # For example: for end-to-end multi-talker ASR you might use a single stream. 139 | # Alternatively, you could use the predicted speaker ID as the stream ID. 140 | 141 | # The wav_file_name column of attributed_segments_df indicates the source CSS stream. 142 | # Note that the diarization module ensures the words within each segment have a consistent channel. 143 | df_tcorc = attributed_segments_df.copy() 144 | # Use factorize to map each unique wav_file_name to an index. 145 | # meeteval.wer.tcorcwer treats speaker_id field as stream id. 146 | df_tcorc['speaker_id'], uniques = pd.factorize(df_tcorc['wav_file_name'], sort=True) 147 | _LOG.debug(f'Found {len(uniques)} streams for tc_orc_wer_hyp.stm') 148 | tcorc_wer_hyp_json = write_json(df_tcorc, 'tc_orc_wer_hyp.json') 149 | 150 | return pd.Series({ 151 | 'session_id': session.session_id, 152 | 'tcp_wer_hyp_json': tcp_wer_hyp_json, 153 | 'tcorc_wer_hyp_json': tcorc_wer_hyp_json, 154 | 'is_mc': session.is_mc, 155 | 'is_close_talk': session.is_close_talk, 156 | }) 157 | -------------------------------------------------------------------------------- /inference_pipeline/load_meeting_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from typing import Tuple, Optional 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import soundfile 8 | from tqdm import tqdm 9 | 10 | from utils.audio_utils import write_wav 11 | from utils.torch_utils import is_zero_rank, barrier 12 | 13 | 14 | def load_data(meetings_dir: str, session_query: Optional[str] = None, 15 | return_close_talk: bool = False, out_dir: Optional[str] = None 16 | ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: 17 | """ 18 | Load all meetings from the meetings dir 19 | 20 | Args: 21 | meetings_dir: directory containing meetings. 22 | Example: project_root/artifacts/meeting_data/dev_set/240121_dev/MTG/ 23 | session_query: a query string to filter the sessions (optional) 24 | When submitting results, this should be None so no filtering occurs. 25 | return_close_talk: if True, return each meeting as a session with all close-talk devices as its 26 | wav_file_names. 27 | Close-talk must not be used during inference. However, this can be used as supervision 28 | signal during training or for analysis. 29 | out_dir: directory to save outputs to. only used when return_close_talk is True. 
30 | Returns: 31 | all_session_df (per device): 32 | Each line corresponds to a recording of a meeting captured with a single device 33 | (referred to as a 'session'). 34 | If a meeting was recorded with N devices (single or multi-channel), the DataFrame should contain 35 | N lines – one for every device recording. 36 | Rules: 37 | - Inference must run independently for each session (device) and no cross-session information 38 | is permitted. 39 | - Use of close-talk microphones is not permitted during inference. 40 | all_gt_utt_df (per utt): 41 | each line is a ground truth utterance 42 | all_gt_metadata_df (per meeting): 43 | each line is a meeting's metadata: participants, topics, 44 | hashtags (#WalkAndTalk, #TalkNearWhiteboard etc. useful for analysis) and more. 45 | """ 46 | meetings_dir = Path(meetings_dir) 47 | 48 | # list to store dataframes for each meeting 49 | gt_utt_dfs = [] 50 | session_dfs = [] 51 | metadata_dfs = [] 52 | 53 | sorted_dirs = sorted(meetings_dir.glob('*/')) 54 | for meeting_subdir in tqdm(sorted_dirs, desc='loading meetings data'): 55 | if not meeting_subdir.is_dir(): 56 | continue 57 | transcription_file = meeting_subdir / 'gt_transcription.json' 58 | devices_file = meeting_subdir / 'devices.json' 59 | metadata_file = meeting_subdir / 'gt_meeting_metadata.json' 60 | 61 | gt_utt_df = None 62 | if transcription_file.exists(): 63 | # we have GT transcription 64 | gt_utt_df = pd.read_json(transcription_file) 65 | # add a 'meeting_id' column 66 | gt_utt_df['meeting_id'] = meeting_subdir.name 67 | gt_utt_dfs.append(gt_utt_df) 68 | 69 | if metadata_file.exists(): 70 | with open(metadata_file, 'r') as file: 71 | metadata = json.load(file) 72 | metadata_df = pd.DataFrame([metadata]) 73 | metadata_dfs.append(metadata_df) 74 | 75 | devices_df = pd.read_json(devices_file) 76 | devices_df['meeting_id'] = meeting_subdir.name 77 | if return_close_talk: 78 | devices_df = devices_df[devices_df.is_close_talk].copy() 79 | assert len(devices_df) > 0, 'no close-talk devices found' 80 | assert gt_utt_df is not None, 'expecting GT transcription' 81 | 82 | new_wav_file_names = concat_speech_segments(devices_df, gt_utt_df, meeting_subdir, out_dir) 83 | 84 | # original close-talk: 85 | # orig_wav_file_names = devices_df.wav_file_names.apply(lambda x: str(meeting_subdir / x)).to_list() 86 | 87 | devices_df = devices_df.iloc[0:1].copy() 88 | devices_df['device_name'] = 'close_talk' 89 | devices_df['wav_file_names'] = [new_wav_file_names] # orig_wav_file_names 90 | devices_df['session_id'] = 'close_talk/' + meeting_subdir.name 91 | else: 92 | # drop close-talk devices 93 | devices_df = devices_df[~devices_df.is_close_talk].copy() 94 | 95 | prefix = devices_df.is_mc.map({True: 'multichannel', False: 'singlechannel'}) 96 | devices_df['session_id'] = prefix + '/' + meeting_subdir.name + '_' + devices_df['device_name'] 97 | # convert to a list of full paths by appending meeting_subdir to each file in wav_file_name 98 | devices_df['wav_file_names'] = devices_df['wav_file_names'].apply( 99 | lambda x: [str(meeting_subdir / file_name.strip()) for file_name in x.split(',')] 100 | ) 101 | 102 | session_dfs.append(devices_df) 103 | 104 | 105 | # concatenate all meetings into one big DataFrame 106 | all_gt_utt_df = pd.concat(gt_utt_dfs, ignore_index=True) if gt_utt_dfs else None 107 | all_session_df = pd.concat(session_dfs, ignore_index=True) 108 | all_metadata_df = pd.concat(metadata_dfs, ignore_index=True) if metadata_dfs else None 109 | 110 | # MtgType column is useful for querying, but it is on 
the metadata df. Merge it into session df. 111 | if all_metadata_df is not None: 112 | merged_df = all_session_df.merge(all_metadata_df[['meeting_id', 'MtgType']], 113 | on='meeting_id', how='inner') 114 | assert len(merged_df) == len(all_session_df) 115 | assert not merged_df.MtgType.isna().any(), 'expecting valid MtgType values' 116 | all_session_df = merged_df 117 | assert not all_session_df.MtgType.str.startswith("read").any(), \ 118 | '"read" meetings are for debug, they are not expected here' 119 | # avoid using MtgType from here on 120 | all_session_df.drop('MtgType', axis=1, inplace=True) 121 | 122 | if session_query: 123 | query, process_first_n = _process_query(session_query) 124 | all_session_df.query(query, inplace=True) 125 | if process_first_n: 126 | all_session_df = all_session_df.head(process_first_n) 127 | 128 | return all_session_df, all_gt_utt_df, all_metadata_df 129 | 130 | 131 | def _process_query(query): 132 | """ Split query into a few parts 133 | Query can have the following format: 134 | 1. "query_string" 135 | 2. "query_string ##and index Path: 14 | """ Returns project root folder """ 15 | return Path(__file__).parent 16 | 17 | 18 | def load_config(config_name: ConfigName) -> InferenceCfg: 19 | """ Returns the inference config (with its session query set) for the given config name """ 20 | project_root = get_project_root() 21 | 22 | updates = {} 23 | if config_name == 'full_dev_set_mc': 24 | # all multi-channel (MC) dev-set sessions 25 | conf_file = project_root / 'configs/inference/inference_v1.yaml' 26 | session_query = "is_mc == True" # filter only MC 27 | 28 | elif config_name == 'full_dev_set_sc': 29 | # all single-channel (SC) dev-set sessions 30 | conf_file = project_root / 'configs/inference/inference_v1.yaml' 31 | session_query = "is_mc == False" # filter only SC 32 | 33 | elif config_name == 'dev_set_mc_debug': 34 | # for quick debug: 'tiny' Whisper, one MC (multi-channel) session 35 | conf_file = project_root / 'configs/inference/debug_inference.yaml' 36 | session_query = 'device_name == "plaza_0" and is_mc == True and meeting_id == "MTG_30860"' 37 | 38 | else: 39 | raise ValueError(f'unknown config name: {config_name}') 40 | 41 | cfg: InferenceCfg = load_yaml_to_dataclass(str(conf_file), InferenceCfg) 42 | cfg = update_dataclass(cfg, updates) 43 | 44 | if session_query is not None: 45 | assert cfg.session_query is None, 'overriding session_query from yaml' 46 | cfg.session_query = session_query 47 | 48 | return cfg 49 | 50 | 51 | def main(config_name: ConfigName = 'dev_set_mc_debug', output_dir: str = ""): 52 | project_root = get_project_root() 53 | cfg: InferenceCfg = load_config(config_name) 54 | 55 | # download the entire dev-set (all sessions, multi-channel and single-channel) 56 | meetings_root = project_root / 'artifacts' / 'meeting_data' 57 | dev_meetings_dir = download_meeting_subset(subset_name='dev_set', 58 | version='240825.1_dev1', # dev-set-1, GT included 59 | destination_dir=str(meetings_root)) 60 | 61 | if dev_meetings_dir is None: 62 | raise RuntimeError('failed to download benchmark dataset') 63 | 64 | # download models 65 | model_set_type = 'css-models' 66 | models_dir = project_root / 'artifacts' 67 | download_models(destination_dir=str(models_dir), set_type=model_set_type) 68 | 69 | # outputs per module will be written here 70 | outputs_dir = (project_root if output_dir == "" else Path(output_dir)) / 'artifacts' / 'outputs' 71 | 72 | cache_cfg = FetchFromCacheCfg() # no cache, use this at your own risk.
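    # Caching can be enabled per module to reuse results from earlier runs (at your own risk), e.g.:
    #   cache_cfg = FetchFromCacheCfg(css=True, asr=True, diarization=False)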
73 | 74 | exp_name = ('pass_through' if cfg.css.pass_through_ch0 else 'css') + '_' + cfg.asr.model_name 75 | outputs_dir = outputs_dir / exp_name 76 | 77 | pprint(f'{config_name=}') 78 | pprint(cfg) 79 | 80 | # run inference pipeline 81 | inference_pipeline(meetings_dir=str(dev_meetings_dir), 82 | models_dir=str(models_dir / model_set_type), 83 | out_dir=str(outputs_dir), 84 | cfg=cfg, 85 | cache=cache_cfg) 86 | 87 | 88 | if __name__ == "__main__": 89 | parser = argparse.ArgumentParser(description='Run inference pipeline') 90 | parser.add_argument('--config-name', type=str, default="dev_set_mc_debug", 91 | help='Config scenario for the inference, default: dev_set_mc_debug') 92 | parser.add_argument('--output-dir', type=str, default="", 93 | help='Output directory path, default: ./artifacts/outputs') 94 | args = parser.parse_args() 95 | 96 | main(args.config_name, args.output_dir) 97 | -------------------------------------------------------------------------------- /run_training_css_local.py: -------------------------------------------------------------------------------- 1 | from css.training.train import main 2 | 3 | if __name__ == '__main__': 4 | main() 5 | -------------------------------------------------------------------------------- /sample_data/css_train_set/0000_Libri.clean-500.book_00082_chp_0009_reader_02124_53.gt_noise: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/NOTSOFAR1-Challenge/6f58e08b008f7530ba4141f0aeb02447c70b6fd7/sample_data/css_train_set/0000_Libri.clean-500.book_00082_chp_0009_reader_02124_53.gt_noise -------------------------------------------------------------------------------- /sample_data/css_train_set/0000_Libri.clean-500.book_00082_chp_0009_reader_02124_53.gt_spk_activity_scores: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/NOTSOFAR1-Challenge/6f58e08b008f7530ba4141f0aeb02447c70b6fd7/sample_data/css_train_set/0000_Libri.clean-500.book_00082_chp_0009_reader_02124_53.gt_spk_activity_scores -------------------------------------------------------------------------------- /sample_data/css_train_set/0000_Libri.clean-500.book_00082_chp_0009_reader_02124_53.gt_spk_direct_early_echoes: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/NOTSOFAR1-Challenge/6f58e08b008f7530ba4141f0aeb02447c70b6fd7/sample_data/css_train_set/0000_Libri.clean-500.book_00082_chp_0009_reader_02124_53.gt_spk_direct_early_echoes -------------------------------------------------------------------------------- /sample_data/css_train_set/0000_Libri.clean-500.book_00082_chp_0009_reader_02124_53.gt_spk_reverb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/NOTSOFAR1-Challenge/6f58e08b008f7530ba4141f0aeb02447c70b6fd7/sample_data/css_train_set/0000_Libri.clean-500.book_00082_chp_0009_reader_02124_53.gt_spk_reverb -------------------------------------------------------------------------------- /sample_data/css_train_set/0000_Libri.clean-500.book_00082_chp_0009_reader_02124_53.json: -------------------------------------------------------------------------------- 1 | {"index_name": null, "index_value": "0000_Libri.clean-500.book_00082_chp_0009_reader_02124_53", "columns": {"gt_spk_activity_scores": {"dtype": "int8", "shape": [160000, 3], "itemsize": 1, "element_in_row": 3, "row_size": 3}, 
"mixture": {"dtype": "int16", "shape": [160000, 7], "itemsize": 2, "element_in_row": 7, "row_size": 14}, "mixture_scale": {"values": "491520.0"}, "gt_spk_direct_early_echoes": {"dtype": "int16", "shape": [160000, 7, 3], "itemsize": 2, "element_in_row": 21, "row_size": 42}, "gt_spk_direct_early_echoes_scale": {"values": "524288.0"}, "gt_spk_reverb": {"dtype": "int16", "shape": [160000, 7, 3], "itemsize": 2, "element_in_row": 21, "row_size": 42}, "gt_spk_reverb_scale": {"values": "1310720.0"}, "gt_noise": {"dtype": "int16", "shape": [160000, 7], "itemsize": 2, "element_in_row": 7, "row_size": 14}, "gt_noise_scale": {"values": "10452992.0"}}} -------------------------------------------------------------------------------- /sample_data/css_train_set/0000_Libri.clean-500.book_00082_chp_0009_reader_02124_53.mixture: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/NOTSOFAR1-Challenge/6f58e08b008f7530ba4141f0aeb02447c70b6fd7/sample_data/css_train_set/0000_Libri.clean-500.book_00082_chp_0009_reader_02124_53.mixture -------------------------------------------------------------------------------- /sample_data/css_train_set/dataset-000000.map: -------------------------------------------------------------------------------- 1 | {"0000_Libri.clean-500.book_00082_chp_0009_reader_02124_53": 160000} -------------------------------------------------------------------------------- /utils/audio_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import soundfile as sf 4 | import scipy.io.wavfile as wf 5 | 6 | MAX_INT16 = np.iinfo(np.int16).max 7 | EPSILON = np.finfo(np.float32).eps 8 | 9 | 10 | def read_wav(fname, beg=None, end=None, normalize=True, return_rate=False): 11 | """ 12 | Read wave files using scipy.io.wavfile(support multi-channel) 13 | """ 14 | # samps_int16: N x C or N 15 | # N: number of samples 16 | # C: number of channels 17 | if beg is not None: 18 | samps_int16, samp_rate = sf.read(fname, 19 | start=beg, 20 | stop=end, 21 | dtype="int16") 22 | else: 23 | samp_rate, samps_int16 = wf.read(fname) 24 | # N x C => C x N 25 | samps = samps_int16.astype(np.float32) 26 | # tranpose because I used to put channel axis first 27 | if samps.ndim != 1: 28 | samps = np.transpose(samps) 29 | # normalize like MATLAB and librosa 30 | if normalize: 31 | samps = samps / MAX_INT16 32 | if return_rate: 33 | return samp_rate, samps 34 | return samps 35 | 36 | 37 | def write_wav(fname, samps: np.ndarray, sr=16000, max_norm: bool = True): 38 | """ 39 | Write wav to file 40 | 41 | max_norm: normalize to [-1, 1] to avoid potential overflow. 
42 | """ 43 | assert samps.ndim == 1 44 | if max_norm: 45 | samps = samps * 0.99 / (np.max(np.abs(samps)) + 1e-7) 46 | 47 | dir_name = os.path.dirname(fname) 48 | os.makedirs(dir_name, exist_ok=True) 49 | sf.write(fname, samps, sr) 50 | 51 | 52 | def play_wav(wav: np.ndarray, fs: int = 16000, volume_factor: float = 1.): 53 | import sounddevice as sd 54 | numpy_audio = wav.squeeze() 55 | sd.play(numpy_audio * volume_factor, fs) -------------------------------------------------------------------------------- /utils/conf.py: -------------------------------------------------------------------------------- 1 | from typing import TypeVar, Type, Dict, Union 2 | from pathlib import Path 3 | import argparse 4 | from dataclasses import dataclass, field 5 | 6 | from omegaconf import OmegaConf 7 | 8 | 9 | ConfT = TypeVar('ConfT') 10 | 11 | 12 | def load_yaml_to_dataclass(yaml_path: Union[str, Path], conf_type: Type[ConfT]) -> ConfT: 13 | """ 14 | Load a YAML file and convert it to a dataclass object. 15 | 16 | Example: 17 | cfg: InferenceCfg = get_conf(conf_file, InferenceCfg) 18 | """ 19 | schema = OmegaConf.structured(conf_type) 20 | conf = OmegaConf.load(yaml_path) 21 | merged = OmegaConf.merge(schema, conf) # this will override schema with values from conf 22 | return OmegaConf.to_object(merged) 23 | 24 | 25 | def update_dataclass(dataclass_obj: ConfT, updates: Dict) -> ConfT: 26 | """ 27 | Update values in dataclass config using either dot-notation or brackets to denote sub-keys 28 | """ 29 | schema = OmegaConf.structured(dataclass_obj) 30 | for k,v in updates.items(): 31 | OmegaConf.update(schema, k, v) 32 | return OmegaConf.to_object(schema) 33 | 34 | 35 | def _demo(): 36 | @dataclass 37 | class CssConf: 38 | lr: float = 0.001 39 | epochs: int = 100 40 | 41 | @dataclass 42 | class Conf: 43 | css: CssConf = field(default_factory=CssConf) 44 | 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument('--verb', choices=['show', 'write-default'], default='show') 47 | parser.add_argument('--yaml_path', default='../configs/conf_demo.yaml') 48 | args = parser.parse_args() 49 | 50 | if args.verb == 'show': 51 | c: Conf = load_yaml_to_dataclass(args.yaml_path, Conf) 52 | print(c) 53 | 54 | elif args.verb == 'write-default': 55 | schema = OmegaConf.structured(Conf) 56 | OmegaConf.save(config=schema, f=args.yaml_path) 57 | print(f'Default config was written to {args.yaml_path}') 58 | 59 | else: 60 | raise ValueError(f'Unknown verb: {args.verb}') 61 | 62 | 63 | if __name__ == '__main__': 64 | _demo() 65 | -------------------------------------------------------------------------------- /utils/hugging_face_helper.py: -------------------------------------------------------------------------------- 1 | import os 2 | from concurrent.futures import ThreadPoolExecutor, wait 3 | from pathlib import Path 4 | from typing import Union, Optional, List 5 | 6 | from tqdm import tqdm 7 | from huggingface_hub import HfApi 8 | 9 | 10 | # Constants 11 | NOTSOFAR_HF_REPO_ID = "microsoft/NOTSOFAR" 12 | 13 | # Initialize Hugging Face API 14 | hugging_face_token = os.getenv('HF_TOKEN') 15 | assert hugging_face_token, ("HuggingFace token not found. Please set the HF_TOKEN environment variable, " 16 | "if you have set it, please restart the session. 
" 17 | "Use README.md (NOTSOFAR-1 Datasets - Download Instructions) for more information.") 18 | _HF_API = HfApi(token=hugging_face_token) 19 | 20 | 21 | def is_hf_dir_exists(subfolder: Union[str, Path]) -> bool: 22 | """ 23 | Check if a subfolder exists in the Hugging Face repository 24 | 25 | Args: 26 | subfolder: path in the repository to check 27 | 28 | Returns: 29 | bool: True if the subfolder exists, False otherwise 30 | """ 31 | assert isinstance(subfolder, (str, Path)), "local_dir should be a string or Path object" 32 | 33 | try: 34 | files = _HF_API.list_repo_tree(repo_id=NOTSOFAR_HF_REPO_ID, repo_type="dataset", path_in_repo=subfolder) 35 | return True if files else False 36 | except Exception as e: 37 | return False 38 | 39 | 40 | def list_hf_dir(root_dir: Union[str, Path], recursive: bool = False) -> List[str]: 41 | """ 42 | List all files in a specified directory in the Hugging Face repository. 43 | 44 | Args: 45 | root_dir (Union[str, Path]): The root directory in the repository to list files from. 46 | recursive (bool): Whether to list files recursively. Defaults to False. 47 | 48 | Returns: 49 | List[str]: A list of file paths in the specified directory. 50 | 51 | Raises: 52 | AssertionError: If the specified directory does not exist in the repository. 53 | """ 54 | root_dir = str(root_dir) 55 | assert is_hf_dir_exists(root_dir), f"Cannot find {root_dir} in the Hugging Face repository" 56 | 57 | try: 58 | return [val.path for val in _HF_API.list_repo_tree( 59 | repo_id=NOTSOFAR_HF_REPO_ID, repo_type="dataset", path_in_repo=root_dir, recursive=recursive)] 60 | except Exception as e: 61 | raise RuntimeError(f"Failed to list directory {root_dir} in the Hugging Face repository: {e}") 62 | 63 | 64 | def list_hf_dir_files(root_dir: Union[str, Path]) -> List[str]: 65 | """ 66 | List all files (excluding directories) in a specified directory recursively in the Hugging Face repository. 67 | 68 | Args: 69 | root_dir (Union[str, Path]): The root directory in the repository to list files from. 70 | 71 | Returns: 72 | List[str]: A list of file paths in the specified directory. 73 | """ 74 | def _is_file(file_path: str) -> bool: 75 | return '.' in os.path.basename(file_path) 76 | 77 | try: 78 | return [dir_file_path for dir_file_path in list_hf_dir(root_dir, recursive=True) if _is_file(dir_file_path)] 79 | except Exception as e: 80 | raise RuntimeError(f"Failed to list files in directory {root_dir} in the Hugging Face repository: {e}") 81 | 82 | 83 | def download_hf_file(file_path: str, local_dir: str, pbar: Optional[tqdm] = None) -> str: 84 | """ 85 | Download a file from the Hugging Face repository. 86 | 87 | Args: 88 | file_path (str): Path of the file in the repository. 89 | local_dir (Path): Local directory to download the file to. 90 | pbar (Optional[tqdm]): tqdm progress bar object (optional). 91 | 92 | Returns: 93 | str: Local file path where the file is downloaded. 94 | 95 | Raises: 96 | RuntimeError: If the file download fails. 
97 | """ 98 | local_file_path = Path(local_dir) / file_path 99 | os.makedirs(local_file_path.parent, exist_ok=True) 100 | 101 | try: 102 | _HF_API.hf_hub_download(repo_id=NOTSOFAR_HF_REPO_ID, filename=file_path, 103 | repo_type="dataset", local_dir=local_dir) 104 | if pbar: 105 | pbar.update(1) # Increment progress bar if provided 106 | except Exception as e: 107 | raise RuntimeError(f"Failed to download file {file_path} from the Hugging Face repository: {e}") 108 | 109 | return str(local_file_path) 110 | 111 | 112 | def download_hf_dir(subfolder: str, local_dir: Union[str, Path], max_workers: int = os.cpu_count()) -> List[str]: 113 | """ 114 | Download all files in a subfolder of the Hugging Face repository. 115 | 116 | Args: 117 | subfolder (str): Path in the repository to download. 118 | local_dir (Union[str, Path]): Local directory to download the files to. 119 | max_workers (int): Number of workers to use for downloading, defaults to number of CPUs. 120 | 121 | Returns: 122 | List[str]: A list of file paths downloaded to the local directory. 123 | 124 | Raises: 125 | AssertionError: If the subfolder does not exist in the repository. 126 | ValueError: If the local_dir is not a string or Path object. 127 | ValueError: If max_workers is not an integer. 128 | RuntimeError: If downloading files fails. 129 | """ 130 | if not is_hf_dir_exists(subfolder): 131 | raise AssertionError(f"Subfolder {subfolder} does not exist in the Hugging Face repository") 132 | if not isinstance(local_dir, (str, Path)): 133 | raise ValueError("local_dir should be a string or Path object") 134 | if not isinstance(max_workers, int): 135 | raise ValueError("max_workers should be an integer") 136 | 137 | local_dir = str(local_dir) 138 | files = list_hf_dir_files(root_dir=subfolder) 139 | 140 | with tqdm(total=len(files), desc="Downloading", unit="file") as pbar: 141 | with ThreadPoolExecutor(max_workers=max_workers) as executor: 142 | futures = [executor.submit(download_hf_file, file, local_dir, pbar) for file in files] 143 | wait(futures) 144 | return files 145 | 146 | 147 | def main(): 148 | """ 149 | Usage example for the Hugging Face helper functions 150 | """ 151 | import tempfile 152 | 153 | # List directory 154 | print("\n>>> Listing directory") 155 | folder_path = "benchmark-datasets/dev_set" 156 | dirs = list_hf_dir(root_dir=folder_path) 157 | print(f"Directories in {folder_path}: {dirs}") 158 | 159 | # List files in a directory 160 | print("\n>>> Listing files in a directory recursively") 161 | folder_path = "benchmark-datasets/dev_set" 162 | files = list_hf_dir_files(root_dir=folder_path) 163 | print(f"Files in {folder_path}: {files}") 164 | print(f"Number of files in {folder_path}: {len(files)}") 165 | 166 | # Check if a directory exists 167 | print("\n>>> Checking if a directory exists") 168 | subfolder = "benchmark-datasets/dev_set/240130.1_dev/MTG/MTG_30860/mc_plaza_0" 169 | print(f"Does {subfolder} exist? 
{is_hf_dir_exists(subfolder)}") 170 | 171 | # Download a directory 172 | print("\n>>> Downloading a directory") 173 | download_dir = 'benchmark-datasets/dev_set/240130.1_dev/MTG/MTG_30860/mc_plaza_0' 174 | with tempfile.TemporaryDirectory() as temp_dir: 175 | print(f"Downloading {download_dir} to {temp_dir}") 176 | downloaded_files_path = download_hf_dir(subfolder=download_dir, local_dir=temp_dir) 177 | print(f"Downloaded files: {downloaded_files_path}") 178 | print(f"Number of downloaded files: {len(downloaded_files_path)}") 179 | 180 | 181 | if __name__ == '__main__': 182 | main() 183 | -------------------------------------------------------------------------------- /utils/logging_def.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import pandas as pd 4 | 5 | # this must be called before any other loggers are instantiated to take effect 6 | logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] [%(name)s] %(message)s') 7 | 8 | # verbose pandas display 9 | pd.set_option('display.precision', 4) 10 | pd.options.display.width = 600 11 | pd.options.display.max_columns = 20 12 | pd.options.display.max_rows = 200 13 | # _LOG.info('display options:\n%s', pprint.pformat(pd.options.display.__dict__, indent=4)) 14 | 15 | 16 | def get_logger(name: str): 17 | """ 18 | All modules should use this function to get a logger. 19 | This way, we ensure all loggers are instantiated after basicConfig() call and inherit the same config. 20 | """ 21 | return logging.getLogger(name) -------------------------------------------------------------------------------- /utils/mic_array_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def multichannel_mic_pos_xyz_cm() -> np.ndarray: 5 | """ 6 | Returns the mic positions in cm of multichannel devices used in NOTSOFAR. 7 | The order matches the wav files of multichannel sessions. 8 | 9 | Returns: 10 | mic_pos_xyz_cm: (7, 3) array of mic positions in cm. 11 | mic_pos_xyz_cm[0, :] is the center microphone's x,y,z. 12 | 13 | """ 14 | # TODO: finalize these numbers 15 | mic1_az = 0. 16 | az_dir_sign = 1. 17 | 18 | mic_pos_xyz_cm = np.empty((7, 3)) 19 | mic_pos_xyz_cm[:] = np.nan 20 | mic_pos_xyz_cm[0, :] = 0. 21 | r = 4.25 22 | for i in range(1, 7): 23 | mic_pos_xyz_cm[i, 0] = r * np.cos(np.deg2rad(az_dir_sign * 60. * (i - 1) + mic1_az)) 24 | mic_pos_xyz_cm[i, 1] = r * np.sin(np.deg2rad(az_dir_sign * 60. * (i - 1) + mic1_az)) 25 | mic_pos_xyz_cm[i, 2] = 0. 26 | assert not np.any(np.isnan(mic_pos_xyz_cm)) 27 | return mic_pos_xyz_cm 28 | 29 | -------------------------------------------------------------------------------- /utils/notsofar_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains functions to download the NOTSOFAR dataset and models. 3 | """ 4 | import os 5 | import shutil 6 | import logging 7 | import tempfile 8 | from pathlib import Path 9 | from typing import Union, Optional, Literal 10 | 11 | from utils.logging_def import get_logger 12 | from utils.hugging_face_helper import download_hf_dir 13 | 14 | _LOG = get_logger('notsofar_dataset') 15 | 16 | 17 | def download_meeting_subset(subset_name: Literal['train_set', 'dev_set', 'eval_set'], 18 | version: str, destination_dir: Union[str, Path], 19 | overwrite: bool = False) -> Optional[str]: 20 | """ 21 | Downloads a subset of the NOTSOFAR recorded meeting dataset. 
22 | 
23 |     The subsets will be released according to the timeline in:
24 |     https://www.chimechallenge.org/current/task2/index#dates
25 | 
26 |     Args:
27 |         subset_name: name of split to download (dev_set / eval_set / train_set)
28 |         version: version to download (240103g / etc.). It's best to use the latest.
29 |         destination_dir: path to the directory where files will be downloaded.
30 |         overwrite: whether to overwrite the output directory if it already exists
31 |                    (warning!: if true, will delete the entire destination_dir if it exists)
32 | 
33 | 
34 |     Latest available versions:
35 | 
36 |     # dev-set-2, no GT available. Submit your systems to the leaderboard to measure WER.
37 |     # dev-set-2 includes mostly new participants compared to the training sets and dev-set-1.
38 |     res_dir = download_meeting_subset(subset_name='dev_set', version='240415.2_dev', destination_dir=...)
39 | 
40 |     # training set: first and second train-set batches and dev-set-1 (GT unveiled) combined.
41 |     # dev-set-1 and the training sets have significant participant overlap. Use dev-set-2 for development.
42 |     res_dir = download_meeting_subset(subset_name='train_set', version='240501.1_train', destination_dir=...)
43 | 
44 | 
45 |     Previous versions:
46 | 
47 |     # this dataset is identical to the updated "240501.1_train" except it includes some faulty multi-channel
48 |     # devices with replicated channels that have been removed in the newer version.
49 |     res_dir = download_meeting_subset(subset_name='train_set', version='240415.1_train', destination_dir=...)
50 | 
51 | 
52 |     # dev-set-1, no GT available. Previous leaderboard was used to measure WER.
53 |     res_dir = download_meeting_subset(subset_name='dev_set', version='240208.2_dev', destination_dir=...)
54 | 
55 |     # first and second train-set batches combined, with GT for training models.
56 |     res_dir = download_meeting_subset(subset_name='train_set', version='240229.1_train', destination_dir=...)
57 | 
58 |     # first train-set batch, with GT for training models.
59 |     res_dir = download_meeting_subset(subset_name='train_set', version='240208.2_train', destination_dir=...)
60 | 
61 | 
62 |     Returns:
63 |         a string indicating the output directory path, or None if the download failed
64 |     """
65 |     set_type = 'benchmark-datasets'
66 |     _LOG.info(f'Downloading {set_type} subset: {subset_name}, version: {version}')
67 | 
68 |     destination_dir = Path(destination_dir)
69 |     if overwrite and destination_dir.exists():
70 |         shutil.rmtree(destination_dir)
71 | 
72 |     hf_subfolder = f'{set_type}/{subset_name}/{version}/MTG'
73 |     download_hf_dir(subfolder=hf_subfolder, local_dir=destination_dir)
74 |     _LOG.info(f'Download completed, download dir: {destination_dir}')
75 | 
76 |     local_dir = destination_dir / hf_subfolder
77 |     return str(local_dir) if local_dir.exists() else None
78 | 
79 | 
80 | def download_simulated_subset(version: str, volume: Literal['200hrs', '1000hrs'],
81 |                               subset_name: Literal['train', 'val'], destination_dir: str,
82 |                               overwrite: bool = False) -> Optional[str]:
83 |     """
84 |     Download the simulated dataset to the destination directory
85 |     Args:
86 |         version: version of the train data to download (v1 / v1.1 / v1.2 / v1.3 / etc.)
87 |         volume: volume of the train data to download (200hrs / 1000hrs)
88 |         subset_name: train data type to download (train / val)
89 |         destination_dir: path to the directory where files will be downloaded.
90 |         overwrite: whether to overwrite the output directory if it already exists
91 |                    (warning!: if true, will delete the entire destination_dir if it exists)
92 | 
93 | 
94 |     Latest available datasets:
95 | 
96 |     # 1000 hours
97 |     train_set_path = download_simulated_subset(version='v1.5', volume='1000hrs', subset_name='train',
98 |                                                destination_dir=...)
99 |     val_set_path = download_simulated_subset(version='v1.5', volume='1000hrs', subset_name='val',
100 |                                              destination_dir=...)
101 | 
102 |     # 200 hours subset
103 |     train_set_path = download_simulated_subset(version='v1.5', volume='200hrs', subset_name='train',
104 |                                                destination_dir=...)
105 |     val_set_path = download_simulated_subset(version='v1.5', volume='200hrs', subset_name='val',
106 |                                              destination_dir=...)
107 | 
108 | 
109 |     Returns:
110 |         a string indicating the output directory path, or None if the download failed
111 |     """
112 |     _LOG.info(f'Downloading simulated subset: {subset_name}, version: {version}, volume: {volume}')
113 |     set_type = 'css-datasets'
114 |     destination_dir = Path(destination_dir)
115 |     if overwrite and destination_dir.exists():
116 |         shutil.rmtree(destination_dir)
117 | 
118 |     hf_subfolder = f'{set_type}/{version}/{volume}/{subset_name}'
119 |     download_hf_dir(subfolder=hf_subfolder, local_dir=destination_dir)
120 |     _LOG.info(f'Download completed: {subset_name}, version: {version}, volume: {volume}')
121 |     return str(destination_dir) if destination_dir.exists() else None
122 | 
123 | 
124 | def download_models(destination_dir: str,
125 |                     set_type: str = 'css-models',
126 |                     version: Literal['conformer0.5', 'conformer1.0'] = 'conformer1.0',
127 |                     pattern: Optional[str] = None, overwrite: bool = False) -> Optional[str]:
128 |     """
129 |     Download the models to the destination directory
130 |     Args:
131 |         destination_dir: path to destination directory to download the models to
132 |         version: version of the models to download (conformer0.5 / conformer1.0), default: conformer1.0
133 |         pattern: pattern to match the models to download.
134 |                  (e.g. 'mc' will download all notsofar baseline mc models).
135 |         overwrite: whether to overwrite the output directory if it already exists
136 |                    (warning!: if true, will delete the entire destination_dir if it exists)
137 |     Returns:
138 |         a string indicating the output directory path, or None if the download failed
139 |     """
140 |     _LOG.info(f'Downloading models: version: {version}, pattern: {pattern}')
141 |     destination_dir = Path(destination_dir)
142 |     models_subdir = f'{set_type}/notsofar/{version}'
143 |     models_local_dir = destination_dir / models_subdir
144 |     if overwrite and models_local_dir.exists():
145 |         shutil.rmtree(models_local_dir)
146 | 
147 |     download_hf_dir(subfolder=f'{models_subdir}{"/" + pattern if pattern else ""}', local_dir=destination_dir)
148 |     _LOG.info(f'Download completed: models version {version}, pattern: {pattern}')
149 |     return str(models_local_dir) if models_local_dir.exists() else None
150 | 
151 | 
152 | def main():
153 |     """
154 |     Usage example for downloading the NOTSOFAR dataset and models.
155 | """ 156 | logging.basicConfig(level=logging.INFO) 157 | 158 | with tempfile.TemporaryDirectory() as temp_dir: 159 | _LOG.info(f'Temp dir: {temp_dir}') 160 | _LOG.info('Downloading NOTSOFAR dataset and models...') 161 | 162 | _LOG.info('Downloading meeting subset') 163 | dev_set_dir = download_meeting_subset( 164 | subset_name='dev_set', version='240208.2_dev', # dev-set is without GT for now 165 | destination_dir=os.path.join(temp_dir, 'meeting_data')) 166 | _LOG.info(f'Dev set dir: {dev_set_dir}') 167 | 168 | _LOG.info('Downloading models') 169 | models_dir = download_models(destination_dir=os.path.join(temp_dir, 'models'), pattern='mc') 170 | _LOG.info(f'Models dir: {models_dir}') 171 | 172 | 173 | if __name__ == '__main__': 174 | main() 175 | -------------------------------------------------------------------------------- /utils/numpy_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def erode(arr: np.ndarray, iters: int): 5 | assert arr.ndim == 1 6 | arr_padded = np.pad(arr, iters, mode='constant', constant_values=1) 7 | return np.lib.stride_tricks.sliding_window_view(arr_padded, 2 * iters + 1).min(1) 8 | 9 | 10 | def dilate(arr: np.ndarray, iters: int): 11 | assert arr.ndim == 1 12 | arr_padded = np.pad(arr, iters, mode='constant', constant_values=0) 13 | return np.lib.stride_tricks.sliding_window_view(arr_padded, 2 * iters + 1).max(1) 14 | 15 | 16 | def test_morphology(): 17 | arr = np.array( [1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0], dtype=bool) 18 | eroded = erode(arr, 1) 19 | assert np.all(eroded == [1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]) 20 | 21 | dilated = dilate(arr, 1) 22 | assert np.all(dilated == [1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0]) 23 | 24 | 25 | if __name__ == "__main__": 26 | test_morphology() -------------------------------------------------------------------------------- /utils/plot_utils.py: -------------------------------------------------------------------------------- 1 | """Plot CSS inference intermediate results for debug. See usage in css.py""" 2 | from pathlib import Path 3 | from typing import Optional 4 | 5 | import numpy as np 6 | import torch 7 | 8 | from utils.audio_utils import play_wav, write_wav 9 | 10 | 11 | def plot_stitched_masks(mask_stitched, activity_b, activity_final, cfg, title_str: Optional[str] = None, 12 | out_filename: Optional[str] = None, segment_frames=None, segment_size_sec=None): 13 | import matplotlib.pyplot as plt 14 | activity = mask_stitched.mean(dim=1) # [B, T, num_spks] 15 | total_plots = cfg.num_spks * 2 16 | time_frames = mask_stitched.size(2) # Assuming the number of time frames is the third dimension 17 | 18 | if segment_frames is not None or segment_size_sec is not None: 19 | assert segment_frames is not None and segment_size_sec is not None, \ 20 | 'Either both segment_frames and segment_size_sec must be provided or none!' 
21 | frames_per_sec = int(segment_frames / segment_size_sec) 22 | else: 23 | frames_per_sec = None 24 | 25 | plt.figure(figsize=(15, 5 * total_plots)) 26 | for j in range(cfg.num_spks): 27 | # Plot for mask_stitched 28 | plt.subplot(total_plots, 1, 2 * j + 1) 29 | plt.imshow(mask_stitched[0, :, :, j], aspect='auto', origin='lower') 30 | # plt.colorbar() 31 | plt.title(f"Speaker {j + 1} Mask") 32 | # plt.xlabel("Time Frames") 33 | plt.ylabel("Frequency Bins") 34 | plt.xlim(0, time_frames - 1) # Set x-axis limits 35 | if frames_per_sec is not None: 36 | plt.xticks(range(0, time_frames, frames_per_sec//2), 37 | list(map(lambda x: x / frames_per_sec, range(0, time_frames, frames_per_sec//2)))) 38 | # Plot for activity 39 | plt.subplot(total_plots, 1, 2 * j + 2) 40 | plt.plot(activity[0, :, j], label='mean mask') 41 | plt.plot(activity_b[:, j], label=f'thresh={cfg.activity_th}') 42 | plt.plot(activity_final[0, :, j], 43 | label=f'dilate({cfg.activity_dilation_sec})->erode({cfg.activity_erosion_sec})') 44 | plt.title(f"Speaker {j + 1} Activity") 45 | # plt.xlabel("Time Frames") 46 | plt.ylabel("Average Activity") 47 | plt.xlim(0, time_frames - 1) # Set x-axis limits to be the same as the mask_stitched plot 48 | plt.ylim(0, 1.05) 49 | if frames_per_sec is not None: 50 | plt.xticks(range(0, time_frames, frames_per_sec//2), 51 | list(map(lambda x: x / frames_per_sec, range(0, time_frames, frames_per_sec//2)))) 52 | plt.legend(loc='best') # Add a legend 53 | plt.suptitle(title_str or 'Speaker Masks and Activities') 54 | 55 | if out_filename is None: 56 | plt.show() 57 | else: 58 | plt.savefig(out_filename, bbox_inches='tight') 59 | 60 | 61 | def plot_left_right_stitch(separator, left_input, right_input, right_perm, overlap_frames, 62 | cfg, stft_seg_to_play: Optional[torch.Tensor]=None, fs: Optional[int]=None): 63 | if stft_seg_to_play is not None: 64 | separator.cpu() 65 | wav = separator.istft(stft_seg_to_play).cpu().numpy() 66 | play_wav(wav.squeeze(), fs, volume_factor=5.) 67 | 68 | left = left_input # overlapping part - [:, :, -overlap_frames:] 69 | right = right_input # overlapping part - [:, :, :overlap_frames] 70 | import matplotlib.pyplot as plt 71 | num_spks = cfg.num_spks 72 | plt.figure(figsize=(15, 5 * num_spks)) 73 | for j in range(num_spks): 74 | plt.subplot(num_spks, 1, j + 1) 75 | plt.imshow(left[0, :, :, j], aspect='auto', origin='lower') 76 | plt.axvline(x=left.shape[2] - overlap_frames, color='red', linestyle='--') 77 | plt.colorbar() 78 | plt.title(f"Speaker {j + 1} Mask") 79 | plt.xlabel("Time Frames") 80 | plt.ylabel("Frequency Bins") 81 | plt.suptitle('left') 82 | plt.show() 83 | plt.figure(figsize=(15, 5 * num_spks)) 84 | for j in range(num_spks): 85 | plt.subplot(num_spks, 1, j + 1) 86 | plt.imshow(right[0, :, :, right_perm[0][j]], aspect='auto', origin='lower') 87 | plt.axvline(x=overlap_frames, color='red', linestyle='--') 88 | plt.colorbar() 89 | plt.title(f"Speaker {j + 1} Mask") 90 | plt.xlabel("Time Frames") 91 | plt.ylabel("Frequency Bins") 92 | plt.suptitle('right') 93 | plt.show() 94 | 95 | 96 | def plot_separation_methods(stft_seg_device_chref, masks, mvdr_responses, separator, cfg, plots): 97 | """Plot various masking methods for multi-channel segment, and writes them as wav files. 98 | plots arg controls what to plot. 
99 | 100 | For full plot: 101 | plots = ['mvdr', 'masked_mvdr', 'spk_masks', 'masked_ref_ch', 'mixture'] 102 | """ 103 | import matplotlib.pyplot as plt 104 | import librosa 105 | plots_ordered = [] 106 | num_spks = cfg.num_spks 107 | fig, axs = plt.subplots(num_spks, len(plots), figsize=(30, 5 * num_spks)) 108 | masked_ref_ch = stft_seg_device_chref.unsqueeze(-1) * masks['spk_masks'] 109 | masked_mvdr = mvdr_responses * masks['spk_masks'] # note, no floor 110 | col_ind = -1 111 | if 'mvdr' in plots: 112 | plots_ordered.append('mvdr') 113 | col_ind += 1 114 | for j in range(num_spks): 115 | ax = axs[j, col_ind] 116 | img = librosa.display.specshow( 117 | librosa.amplitude_to_db(mvdr_responses[0, :, :, j].abs().cpu(), ref=np.max), 118 | y_axis='linear', x_axis='time', ax=ax, sr=16000) 119 | ax.set_title(f'Speaker {j + 1} Spectrogram') 120 | plt.colorbar(img, ax=ax, format="%+2.0f dB") 121 | ax.set_xlabel("Time Frames") 122 | ax.set_ylabel("Frequency Bins") 123 | if 'masked_mvdr' in plots: 124 | plots_ordered.append('masked_mvdr') 125 | col_ind += 1 126 | for j in range(num_spks): 127 | ax = axs[j, col_ind] 128 | img = librosa.display.specshow( 129 | librosa.amplitude_to_db(masked_mvdr[0, :, :, j].abs().cpu(), ref=np.max), 130 | y_axis='linear', x_axis='time', ax=ax, sr=16000) 131 | ax.set_title(f'Speaker {j + 1} Spectrogram') 132 | plt.colorbar(img, ax=ax, format="%+2.0f dB") 133 | ax.set_xlabel("Time Frames") 134 | ax.set_ylabel("Frequency Bins") 135 | if 'masked_ref_ch' in plots: 136 | plots_ordered.append('masked_ref_ch') 137 | col_ind += 1 138 | for j in range(num_spks): 139 | ax = axs[j, col_ind] 140 | img = librosa.display.specshow( 141 | librosa.amplitude_to_db(masked_ref_ch[0, :, :, j].abs().cpu(), ref=np.max), 142 | y_axis='linear', x_axis='time', ax=ax, sr=16000) 143 | ax.set_title(f'Speaker {j + 1} Spectrogram') 144 | plt.colorbar(img, ax=ax, format="%+2.0f dB") 145 | ax.set_xlabel("Time Frames") 146 | ax.set_ylabel("Frequency Bins") 147 | if 'spk_masks' in plots: 148 | plots_ordered.append('spk_masks') 149 | col_ind += 1 150 | for j in range(num_spks): 151 | ax = axs[j, col_ind] 152 | img = ax.imshow(masks['spk_masks'][0, :, :, j].cpu(), aspect='auto', origin='lower', vmin=0, 153 | vmax=1) 154 | plt.colorbar(img, ax=ax) 155 | ax.set_xlabel("Time Frames") 156 | ax.set_ylabel("Frequency Bins") 157 | if 'mixture' in plots: 158 | plots_ordered.append('mixture') 159 | col_ind += 1 160 | # plot mixture ch0 161 | ax = axs[0, col_ind] 162 | img_right = librosa.display.specshow( 163 | librosa.amplitude_to_db(stft_seg_device_chref[0, :, :].abs().cpu(), ref=np.max), 164 | y_axis='linear', x_axis='time', ax=ax) 165 | plt.colorbar(img_right, ax=ax, format="%+2.0f dB") 166 | ax.set_xlabel("Time Frames") 167 | ax.set_ylabel("Frequency Bins") 168 | 169 | # plot noisemask 170 | ax = axs[1, col_ind] 171 | img = ax.imshow(masks['noise_masks'][0, :, :, 0].cpu(), aspect='auto', origin='lower', vmin=0, vmax=1) 172 | plt.colorbar(img, ax=ax) 173 | ax.set_xlabel("Time Frames") 174 | ax.set_ylabel("Frequency Bins") 175 | 176 | plt.suptitle(' | '.join(plots_ordered)) 177 | plt.tight_layout() 178 | plt.show() 179 | 180 | istft = lambda x: separator.istft(x).cpu().numpy()[0] 181 | # x: [B, num_spks, Nsamples] 182 | out_dir = Path('artifacts/analysis/separated_seg') 183 | write_wav(out_dir / 'input_ref_ch.wav', samps=istft(stft_seg_device_chref), sr=16000) 184 | for j in range(num_spks): 185 | write_wav(out_dir / f'masked_ref_ch{j}.wav', samps=istft(masked_ref_ch[..., j]), sr=16000) 186 | write_wav(out_dir 
/ f'mvdr_{j}.wav', samps=istft(mvdr_responses[..., j]), sr=16000) 187 | write_wav(out_dir / f'masked_mvdr_{j}.wav', samps=istft(masked_mvdr[..., j]), sr=16000) -------------------------------------------------------------------------------- /utils/scoring.py: -------------------------------------------------------------------------------- 1 | import decimal 2 | from functools import partial 3 | from pathlib import Path 4 | from dataclasses import dataclass 5 | from typing import List, Dict, Callable 6 | import os 7 | 8 | import pandas as pd 9 | import meeteval 10 | import meeteval.io.chime7 11 | from meeteval.io.seglst import SegLstSegment 12 | from meeteval.viz.visualize import AlignmentVisualization 13 | 14 | from utils.logging_def import get_logger 15 | from utils.text_norm_whisper_like import get_txt_norm 16 | 17 | _LOG = get_logger('wer') 18 | 19 | 20 | @dataclass 21 | class ScoringCfg: 22 | # If True, saves reference - hypothesis visualizations (self-contained html) 23 | save_visualizations: bool = False 24 | 25 | 26 | def df_to_seglst(df): 27 | return meeteval.io.SegLST([ 28 | SegLstSegment( 29 | session_id=row.session_id, 30 | start_time=decimal.Decimal(row.start_time), 31 | end_time=decimal.Decimal(row.end_time), 32 | words=row.text, 33 | speaker=row.speaker_id, 34 | ) 35 | for row in df.itertuples() 36 | ]) 37 | 38 | 39 | def normalize_segment(segment: SegLstSegment, tn): 40 | words = segment["words"] 41 | words = tn(words) 42 | segment["words"] = words 43 | return segment 44 | 45 | 46 | def calc_wer(out_dir: str, 47 | tcp_wer_hyp_json: str | List[Dict], 48 | tcorc_wer_hyp_json: str | List[Dict], 49 | gt_utt_df: pd.DataFrame, tn: str | Callable = 'chime8', 50 | collar: float = 5, save_visualizations: bool = False) -> pd.DataFrame: 51 | """ 52 | Calculates tcpWER and tcorcWER for each session in hypothesis files using meeteval, and saves the error 53 | information to .json. 54 | Text normalization is applied to both hypothesis and reference. 55 | 56 | Args: 57 | out_dir: the directory to save the ref.json reference transcript to (extracted from gt_utt_df). 58 | tcp_wer_hyp_json: path to hypothesis .json file for tcpWER, or json structure. 59 | tcorc_wer_hyp_json: path to hypothesis .json file for tcorcWER, or json structure. 60 | gt_utt_df: dataframe of ground truth utterances. must include the sessions in the hypothesis files. 61 | see load_data() function. 62 | tn: text normalizer 63 | collar: tolerance of tcpWER to temporal misalignment between hypothesis and reference. 64 | save_visualizations: if True, save html visualizations of alignment between hyp and ref. 65 | Returns: 66 | wer_df: pd.DataFrame with columns - 67 | 'session_id' - same as in hypothesis files 68 | 'tcp_wer': tcpWER 69 | 'tcorc_wer': tcorcWER 70 | ... intermediate tcpWER/tcorcWER fields such as insertions/deletions. see in code. 71 | """ 72 | # json to SegLST structure (Segment-wise Long-form Speech Transcription annotation) 73 | to_seglst = lambda x: meeteval.io.chime7.json_to_stm(x, None).to_seglst() if isinstance(x, list) \ 74 | else meeteval.io.load(Path(x)) 75 | tcp_hyp_seglst = to_seglst(tcp_wer_hyp_json) 76 | tcorc_hyp_seglst = to_seglst(tcorc_wer_hyp_json) 77 | 78 | # map session_id to meetind_id and join with gt_utt_df to include GT utterances for each session. 79 | # since every meeting contributes several sessions, a meeting's GT will be repeated for every session. 
80 | sess2meet_id = tcp_hyp_seglst.groupby('session_id').keys() 81 | sess2meet_id = pd.DataFrame(sess2meet_id, columns=['session_id']) 82 | sess2meet_id['meeting_id'] = sess2meet_id['session_id'].str.extract(r'(MTG_\d+)') 83 | joined_df = pd.merge(sess2meet_id, gt_utt_df, on='meeting_id', how='left') 84 | ref_seglst = df_to_seglst(joined_df) 85 | 86 | if isinstance(tn, str): 87 | tn = get_txt_norm(tn) 88 | # normalization should be idempotent so a second normalization will not change the result 89 | tcp_hyp_seglst = tcp_hyp_seglst.map(partial(normalize_segment, tn=tn)) 90 | tcorc_hyp_seglst = tcorc_hyp_seglst.map(partial(normalize_segment, tn=tn)) 91 | ref_seglst = ref_seglst.map(partial(normalize_segment, tn=tn)) 92 | 93 | ref_file_path = Path(out_dir) / 'ref.json' 94 | ref_file_path.parent.mkdir(parents=True, exist_ok=True) 95 | ref_seglst.dump(ref_file_path) 96 | 97 | def save_wer_visualization(ref, hyp): 98 | ref = ref.groupby('session_id') 99 | hyp = hyp.groupby('session_id') 100 | assert len(ref) == 1 and len(hyp) == 1, 'expecting one session for visualization' 101 | assert list(ref.keys())[0] == list(hyp.keys())[0] 102 | 103 | meeting_name = list(ref.keys())[0] 104 | av = AlignmentVisualization(ref[meeting_name], hyp[meeting_name], alignment='tcp') 105 | # Create standalone HTML file 106 | av.dump(os.path.join(out_dir, 'viz.html')) 107 | 108 | def calc_session_tcp_wer(ref, hyp): 109 | res = meeteval.wer.tcpwer(reference=ref, hypothesis=hyp, collar=collar) 110 | 111 | res_df = pd.DataFrame.from_dict(res, orient='index').reset_index(names='session_id') 112 | keys = ['error_rate', 'errors', 'length', 'insertions', 'deletions', 'substitutions', 113 | 'missed_speaker', 'falarm_speaker', 'scored_speaker', 'assignment'] 114 | return (res_df[['session_id'] + keys] 115 | .rename(columns={k: 'tcp_' + k for k in keys}) 116 | .rename(columns={'tcp_error_rate': 'tcp_wer'})) 117 | 118 | def calc_session_tcorc_wer(ref, hyp): 119 | res = meeteval.wer.tcorcwer(reference=ref, hypothesis=hyp, collar=collar) 120 | 121 | res_df = pd.DataFrame.from_dict(res, orient='index').reset_index(names='session_id') 122 | keys = ['error_rate', 'errors', 'length', 'insertions', 'deletions', 'substitutions', 'assignment'] 123 | return (res_df[['session_id'] + keys] 124 | .rename(columns={k: 'tcorc_' + k for k in keys}) 125 | .rename(columns={'tcorc_error_rate': 'tcorc_wer'})) 126 | 127 | tcp_wer_res = calc_session_tcp_wer(ref_seglst, tcp_hyp_seglst) 128 | tcorc_wer_res = calc_session_tcorc_wer(ref_seglst, tcorc_hyp_seglst) 129 | if save_visualizations: 130 | save_wer_visualization(ref_seglst, tcp_hyp_seglst) 131 | 132 | wer_df = pd.concat([tcp_wer_res, tcorc_wer_res.drop(columns='session_id')], axis=1) 133 | 134 | if isinstance(tcp_wer_hyp_json, str | Path): 135 | wer_df['tcp_wer_hyp_json'] = tcp_wer_hyp_json 136 | if isinstance(tcorc_wer_hyp_json, str | Path): 137 | wer_df['tcorc_wer_hyp_json'] = tcorc_wer_hyp_json 138 | 139 | _LOG.info('Done calculating WER') 140 | _LOG.info(f"\n{wer_df[['session_id', 'tcp_wer', 'tcorc_wer']]}") 141 | 142 | return wer_df 143 | 144 | 145 | def write_submission_jsons(out_dir: str, hyp_jsons_df: pd.DataFrame): 146 | """ 147 | Merges the per-session jsons in hyp_jsons_df and writes them under the appropriate track folder 148 | in out_dir. 149 | The resulting jsons can be used for submission. 
150 | """ 151 | # close-talk is not supposed to be used for scoring 152 | hyp_jsons_df = hyp_jsons_df[~hyp_jsons_df.is_close_talk] 153 | 154 | def write_json(files, file_name, is_mc): 155 | seglst = [] 156 | for f in files: 157 | data = meeteval.io.load(f) 158 | seglst.extend(data) 159 | seglst = meeteval.io.SegLST(seglst) 160 | track = 'multichannel' if is_mc else 'singlechannel' 161 | filepath = Path(out_dir) / 'wer' / track / file_name 162 | seglst.dump(filepath) 163 | _LOG.info(f'Wrote hypothesis transcript for submission: {filepath}') 164 | 165 | mc_hyps = hyp_jsons_df[hyp_jsons_df.is_mc] 166 | sc_hyps = hyp_jsons_df[~hyp_jsons_df.is_mc] 167 | 168 | if len(mc_hyps) > 0: 169 | write_json(mc_hyps.tcp_wer_hyp_json, 'tcp_wer_hyp.json', is_mc=True) 170 | write_json(mc_hyps.tcorc_wer_hyp_json, 'tc_orc_wer_hyp.json', is_mc=True) 171 | 172 | if len(sc_hyps) > 0: 173 | write_json(sc_hyps.tcp_wer_hyp_json, 'tcp_wer_hyp.json', is_mc=False) 174 | write_json(sc_hyps.tcorc_wer_hyp_json, 'tc_orc_wer_hyp.json', is_mc=False) 175 | 176 | -------------------------------------------------------------------------------- /utils/text_norm_whisper_like/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 OpenAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /utils/text_norm_whisper_like/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | NOTSOFAR adopts the same text normalizer as the CHiME-8 DASR track. 
3 | This code is aligned with the CHiME-8 repo: 4 | https://github.com/chimechallenge/chime-utils/tree/main/chime_utils/text_norm 5 | """ 6 | 7 | from .basic import BasicTextNormalizer as BasicTextNormalizer 8 | from .english import EnglishTextNormalizer as EnglishTextNormalizer 9 | 10 | 11 | def get_txt_norm(txt_norm): 12 | assert txt_norm in ["chime8", None] 13 | if txt_norm is None: 14 | return None 15 | elif txt_norm == "chime8": 16 | return EnglishTextNormalizer() 17 | else: 18 | raise NotImplementedError 19 | -------------------------------------------------------------------------------- /utils/text_norm_whisper_like/basic.py: -------------------------------------------------------------------------------- 1 | import re 2 | import unicodedata 3 | 4 | import regex 5 | 6 | # non-ASCII letters that are not separated by "NFKD" normalization 7 | ADDITIONAL_DIACRITICS = { 8 | "œ": "oe", 9 | "Œ": "OE", 10 | "ø": "o", 11 | "Ø": "O", 12 | "æ": "ae", 13 | "Æ": "AE", 14 | "ß": "ss", 15 | "ẞ": "SS", 16 | "đ": "d", 17 | "Đ": "D", 18 | "ð": "d", 19 | "Ð": "D", 20 | "þ": "th", 21 | "Þ": "th", 22 | "ł": "l", 23 | "Ł": "L", 24 | } 25 | 26 | 27 | def remove_symbols_and_diacritics(s: str, keep=""): 28 | """ 29 | Replace any other markers, symbols, and punctuations with a space, 30 | and drop any diacritics (category 'Mn' and some manual mappings) 31 | """ 32 | return "".join( 33 | ( 34 | c 35 | if c in keep 36 | else ( 37 | ADDITIONAL_DIACRITICS[c] 38 | if c in ADDITIONAL_DIACRITICS 39 | else ( 40 | "" 41 | if unicodedata.category(c) == "Mn" 42 | else " " 43 | if unicodedata.category(c)[0] in "MSP" 44 | else c 45 | ) 46 | ) 47 | ) 48 | for c in unicodedata.normalize("NFKD", s) 49 | ) 50 | 51 | 52 | def remove_symbols(s: str): 53 | """ 54 | Replace any other markers, symbols, 55 | punctuations with a space, keeping diacritics 56 | """ 57 | return "".join( 58 | " " if unicodedata.category(c)[0] in "MSP" else c 59 | for c in unicodedata.normalize("NFKC", s) 60 | ) 61 | 62 | 63 | class BasicTextNormalizer: 64 | def __init__(self, remove_diacritics: bool = False, split_letters: bool = False): 65 | self.clean = ( 66 | remove_symbols_and_diacritics if remove_diacritics else remove_symbols 67 | ) 68 | self.split_letters = split_letters 69 | 70 | def __call__(self, s: str): 71 | s = s.lower() 72 | # remove words between brackets 73 | s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) 74 | # remove words between parenthesis 75 | s = re.sub(r"\(([^)]+?)\)", "", s) 76 | s = self.clean(s).lower() 77 | 78 | if self.split_letters: 79 | s = " ".join(regex.findall(r"\X", s, regex.U)) 80 | 81 | s = re.sub( 82 | r"\s+", " ", s 83 | ) # replace any successive whitespace characters with a space 84 | 85 | return s 86 | -------------------------------------------------------------------------------- /utils/text_norm_whisper_like/pre_english.json: -------------------------------------------------------------------------------- 1 | { 2 | "shan't": "shall not", 3 | "han't": "has not", 4 | "ain't": "ain not" 5 | } -------------------------------------------------------------------------------- /utils/torch_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Dict 3 | import pandas as pd 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.distributed as dist 8 | 9 | 10 | def is_dist_env_available(): 11 | return os.environ.get('WORLD_SIZE', None) is not None 12 | 13 | 14 | def is_dist_initialized(): 15 | """ 16 | Returns True if distributed mode has 
been initiated (torch.distributed.init_process_group) 17 | """ 18 | return dist.is_available() and dist.is_initialized() 19 | 20 | 21 | def get_world_size() -> int: 22 | return dist.get_world_size() if is_dist_initialized() else 1 23 | 24 | 25 | def get_rank(): 26 | return dist.get_rank() if is_dist_initialized() else 0 27 | 28 | 29 | def is_zero_rank(): 30 | return get_rank() == 0 31 | 32 | 33 | def barrier(): 34 | if is_dist_initialized(): 35 | dist.barrier() 36 | 37 | 38 | def get_device_name(): 39 | if is_dist_initialized(): 40 | # when the number of nodes is 1, we can use only get_rank() to get the device_id 41 | # but when the number of nodes is greater than 1, the device_id can be calculated by: 42 | device_id = get_rank() % torch.cuda.device_count() 43 | return f'cuda:{device_id}' 44 | 45 | return "cuda" if torch.cuda.is_available() else "cpu" 46 | 47 | 48 | class DDPRowIterator: 49 | """ A class that wraps a DataFrame, such that the returned DataFrame row number is divided by the world size 50 | (i.e. the number of processes created by the DDP). The padded rows are filled with the row at the given 51 | dummy_row_idx field. 52 | This is useful for distributed inference, where we want to distribute the data across all processes, such that 53 | all processes are working on different rows at the same time, while no process is idle (DDP assumption). 54 | The next() method returns a tuple of (row, row_idx, is_dummy) where is_dummy is True if the row is a padded row. 55 | Each process will iterate over the rows that are assigned to it, and then stop when the rows are exhausted. 56 | 57 | Args: 58 | df (pd.DataFrame): the DataFrame to iterate over 59 | """ 60 | 61 | def __init__(self, df: pd.DataFrame): 62 | self.df = df 63 | self.world_size = get_world_size() 64 | self.current_process_idx = get_rank() 65 | self.rows_per_chunk = len(df) // self.world_size 66 | self.remainder = len(df) % self.world_size 67 | self.current_row_idx = 0 68 | self.dummy_row_idx = self.current_process_idx 69 | assert self.dummy_row_idx < len(self.df), f'{self.dummy_row_idx=} must be less than {len(self.df)=}' 70 | 71 | @property 72 | def _padded_df_len(self): 73 | return len(self.df) + ((self.world_size - self.remainder) if self.remainder > 0 else 0) 74 | 75 | def __len__(self): 76 | return int(self._padded_df_len / self.world_size) 77 | 78 | def __iter__(self): 79 | self.current_row_idx = self.current_process_idx 80 | return self 81 | 82 | def __next__(self): 83 | if self.current_row_idx >= self._padded_df_len: 84 | # Wait for all processes to finish processing 85 | barrier() 86 | raise StopIteration 87 | 88 | row_idx = self.current_row_idx 89 | 90 | if row_idx < len(self.df): 91 | is_dummy = False 92 | row = self.df.iloc[row_idx] 93 | else: 94 | # if we are here, we are padding the DataFrame (self.current_row_idx >= len(self.df)) 95 | is_dummy = True 96 | row = self.df.iloc[self.dummy_row_idx] 97 | 98 | self.current_row_idx += self.world_size 99 | return row, row_idx, is_dummy 100 | 101 | 102 | def initialize_ddp(logger): 103 | """ Process group initialization for distributed inference """ 104 | if is_dist_env_available(): 105 | rank = int(os.environ['RANK']) 106 | world_size = int(os.environ['WORLD_SIZE']) 107 | dist.init_process_group('nccl', rank=rank, world_size=world_size) 108 | logger.info(f'Distributed: {get_rank()=}, {get_world_size()=}') 109 | # NOTE! must call set_device or allocations go to GPU 0 disproportionally, causing CUDA OOM. 
110 | torch.cuda.set_device(torch.device(get_device_name())) 111 | dist.barrier() 112 | 113 | return get_device_name() 114 | 115 | 116 | def get_max_value(value: int) -> int: 117 | """ Returns the maximum value from all processes """ 118 | if not is_dist_initialized(): 119 | return value 120 | 121 | tensor = torch.tensor(value).cuda() 122 | dist.all_reduce(tensor, op=dist.ReduceOp.MAX) 123 | return int(tensor.item()) 124 | 125 | 126 | def move_to(obj: Any, device: torch.device, numpy: bool=False) -> Any: 127 | """recursively visit a tuple/list/dict structure (can extend to more types if required)""" 128 | # pylint: disable=unidiomatic-typecheck # explicitly differentiate tuple from NamedTuple 129 | if type(obj) is tuple or isinstance(obj, list): # modify sequence by rebuilding it 130 | # noinspection PyArgumentList 131 | return type(obj)(move_to(x, device, numpy) for x in obj) 132 | if hasattr(obj, 'to'): 133 | # noinspection PyCallingNonCallable 134 | obj = obj.to(device) 135 | # convert floating point types to float32. 136 | obj = obj.float() if 'float' in str(obj.dtype) else obj 137 | return obj.numpy() if numpy else obj 138 | if isinstance(obj, dict): 139 | return type(obj)(**{k: move_to(v, device, numpy) for k,v in obj.items()}) 140 | if isinstance(obj, tuple): # NamedTuple case 141 | # noinspection PyArgumentList 142 | return type(obj)(*(move_to(x, device, numpy) for x in obj)) 143 | return obj 144 | 145 | 146 | def catch_unused_params(model: nn.Module): 147 | """ 148 | Throws error and reports unused parameters in case there are any. 149 | Useful for catching such parameters to prevent torch.nn.parallel.DistributedDataParallel from crashing. 150 | 151 | Note: Call this after backward pass. 152 | """ 153 | unused = [name for name, param in model.named_parameters() 154 | if param.grad is None and param.requires_grad] 155 | unused_str = "\n" + "\n".join(unused) 156 | assert len(unused) == 0, f'Found unused parameters: {unused_str}' 157 | 158 | 159 | def reduce_dict_to_rank0(input_dict: Dict, average: bool): 160 | """ 161 | Args: 162 | input_dict (dict): all the values will be reduced 163 | average (bool): whether to do average or sum 164 | Reduce the values in the dictionary from all processes so that process with rank 165 | 0 has the averaged results. Returns a dict with the same fields as 166 | input_dict, after reduction. 167 | """ 168 | world_size = get_world_size() 169 | if world_size < 2: 170 | return input_dict 171 | with torch.no_grad(): 172 | names = [] 173 | values = [] 174 | # sort the keys so that they are consistent across processes 175 | for k in sorted(input_dict.keys()): 176 | names.append(k) 177 | values.append(input_dict[k]) 178 | values = torch.stack(values, dim=0) 179 | dist.reduce(values, dst=0) 180 | if dist.get_rank() == 0 and average: 181 | # only main process gets accumulated, so only divide by 182 | # world_size in this case 183 | values /= world_size 184 | reduced_dict = {k: v for k, v in zip(names, values)} 185 | return reduced_dict 186 | --------------------------------------------------------------------------------
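Usage sketch (not part of the repository): a minimal example of how the DDP helpers in utils/torch_utils.py might be combined for distributed inference. The DataFrame contents and the per-row work are placeholders, and the script is assumed to run from the repository root (under torchrun for multi-GPU, or directly as a single process).

import logging

import pandas as pd

from utils.torch_utils import initialize_ddp, DDPRowIterator, barrier, is_zero_rank

logger = logging.getLogger(__name__)
device = initialize_ddp(logger)  # falls back to plain 'cuda'/'cpu' when no distributed environment is set

sessions_df = pd.DataFrame({'session_id': ['sess_0', 'sess_1', 'sess_2']})  # placeholder work items
results = []
for row, row_idx, is_dummy in DDPRowIterator(sessions_df):
    out = row.session_id  # placeholder for the real per-row work, e.g. process_session(row, device)
    if not is_dummy:
        results.append((row_idx, out))  # padded (dummy) rows are still processed, but their results are discarded

barrier()  # all ranks sync here; a no-op in a single-process run
if is_zero_rank():
    logger.info('rank 0 processed %d real rows', len(results))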