├── gss ├── utils │ ├── __init__.py │ ├── numpy_utils.py │ └── data_utils.py ├── __init__.py ├── wpe │ ├── __init__.py │ └── wpe.py ├── bin │ ├── __init__.py │ ├── modes │ │ ├── __init__.py │ │ ├── cli_base.py │ │ ├── utils.py │ │ └── enhance.py │ └── gss.py ├── beamformer │ ├── __init__.py │ ├── souden_mvdr.py │ ├── utils.py │ └── beamform.py ├── core │ ├── __init__.py │ ├── wpe.py │ ├── beamformer.py │ ├── activity.py │ ├── gss.py │ ├── stft_module.py │ └── enhancer.py └── cacgmm │ ├── __init__.py │ ├── cacgmm.py │ ├── utils.py │ ├── cacg.py │ ├── cacg_trainer.py │ └── cacgmm_trainer.py ├── recipes ├── ami │ ├── conf │ ├── path.sh │ ├── utils │ ├── run_train.sh │ └── run.sh ├── chime6 │ ├── conf │ ├── utils │ ├── path.sh │ ├── env.sh │ └── run.sh ├── dipco │ ├── conf │ ├── utils │ ├── path.sh │ └── run.sh ├── alimeeting │ ├── conf │ ├── utils │ ├── path.sh │ └── run.sh └── libricss │ ├── path.sh │ ├── conf │ └── gpu.conf │ ├── run.sh │ └── utils │ ├── acquire-gpu │ ├── parse_options.sh │ ├── queue.pl │ └── queue-ackgpu.pl ├── test.pstats ├── .gitmodules ├── CITATION.cff ├── .pre-commit-config.yaml ├── LICENSE ├── setup.py ├── .gitignore ├── .github └── workflows │ └── style_check.yml └── README.md /gss/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /recipes/ami/conf: -------------------------------------------------------------------------------- 1 | ../libricss/conf -------------------------------------------------------------------------------- /recipes/ami/path.sh: -------------------------------------------------------------------------------- 1 | ../libricss/path.sh -------------------------------------------------------------------------------- /recipes/ami/utils: -------------------------------------------------------------------------------- 1 | ../libricss/utils -------------------------------------------------------------------------------- /recipes/chime6/conf: -------------------------------------------------------------------------------- 1 | ../libricss/conf -------------------------------------------------------------------------------- /recipes/chime6/utils: -------------------------------------------------------------------------------- 1 | ../libricss/utils -------------------------------------------------------------------------------- /recipes/dipco/conf: -------------------------------------------------------------------------------- 1 | ../libricss/conf/ -------------------------------------------------------------------------------- /recipes/dipco/utils: -------------------------------------------------------------------------------- 1 | ../libricss/utils -------------------------------------------------------------------------------- /gss/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import * 2 | -------------------------------------------------------------------------------- /recipes/alimeeting/conf: -------------------------------------------------------------------------------- 1 | ../libricss/conf -------------------------------------------------------------------------------- /recipes/alimeeting/utils: -------------------------------------------------------------------------------- 1 | ../libricss/utils -------------------------------------------------------------------------------- /recipes/chime6/path.sh: 
-------------------------------------------------------------------------------- 1 | ../libricss/path.sh -------------------------------------------------------------------------------- /recipes/dipco/path.sh: -------------------------------------------------------------------------------- 1 | ../libricss/path.sh -------------------------------------------------------------------------------- /gss/wpe/__init__.py: -------------------------------------------------------------------------------- 1 | from .wpe import wpe 2 | -------------------------------------------------------------------------------- /recipes/alimeeting/path.sh: -------------------------------------------------------------------------------- 1 | ../libricss/path.sh -------------------------------------------------------------------------------- /gss/bin/__init__.py: -------------------------------------------------------------------------------- 1 | from gss.bin.modes import * 2 | -------------------------------------------------------------------------------- /gss/beamformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .souden_mvdr import beamform_mvdr 2 | -------------------------------------------------------------------------------- /test.pstats: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desh2608/gss/HEAD/test.pstats -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "pb_bss"] 2 | path = pb_bss 3 | url = https://github.com/fgnt/pb_bss.git 4 | -------------------------------------------------------------------------------- /gss/bin/modes/__init__.py: -------------------------------------------------------------------------------- 1 | from .cli_base import * 2 | from .enhance import * 3 | from .utils import * 4 | -------------------------------------------------------------------------------- /gss/core/__init__.py: -------------------------------------------------------------------------------- 1 | from .activity import Activity 2 | from .beamformer import Beamformer 3 | from .gss import GSS 4 | from .stft_module import istft, stft 5 | from .wpe import WPE 6 | -------------------------------------------------------------------------------- /gss/cacgmm/__init__.py: -------------------------------------------------------------------------------- 1 | from .cacg import ComplexAngularCentralGaussian 2 | from .cacg_trainer import ComplexAngularCentralGaussianTrainer 3 | from .cacgmm import CACGMM 4 | from .cacgmm_trainer import CACGMMTrainer 5 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 0.2.0 2 | message: "If you use this software, please cite it as below." 
3 | authors: 4 | - family-names: Raj 5 | given-names: Desh 6 | orcid: http://orcid.org/0000-0002-5038-9400 7 | title: "GPU-accelerated Guided Source Separation" 8 | version: 0.2.0 9 | date-released: 2022-04-07 10 | -------------------------------------------------------------------------------- /recipes/libricss/path.sh: -------------------------------------------------------------------------------- 1 | if [ -f env.sh ]; then 2 | source env.sh 3 | fi 4 | 5 | # Print immediately 6 | export PYTHONUNBUFFERED=1 7 | 8 | export PATH=${PATH}:`pwd`/utils 9 | 10 | # Activate environment 11 | . /home/draj/anaconda3/etc/profile.d/conda.sh && conda deactivate && conda activate gss 12 | 13 | export LC_ALL=C 14 | -------------------------------------------------------------------------------- /gss/bin/gss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Use this script like: 4 | $ gss enhance --help 5 | $ gss utils --help 6 | """ 7 | 8 | # Note: we import all the CLI modes here so they get auto-registered 9 | # in the main CLI entry-point. Then, setuptools is told to 10 | # invoke the "cli()" method from this script. 11 | from gss.bin.modes import * 12 | -------------------------------------------------------------------------------- /recipes/chime6/env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # enable module support 4 | source /etc/profile.d/modules.sh 5 | module load gcc/7.2.0 || exit 1 6 | module use /home/hltcoe/draj/modulefiles || exit 1 7 | module load cuda || exit 1 # loads CUDA 10.2 8 | module load cudnn || exit 1 # loads cuDNN 8.0.2 9 | module load intel/mkl/64/2019/5.281 || exit 1 10 | module load nccl || exit 1 11 | -------------------------------------------------------------------------------- /gss/bin/modes/cli_base.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import click 4 | 5 | 6 | @click.group() 7 | def cli(): 8 | """ 9 | The shell entry point to `gss`, a tool and a library for GSS-based front-end enhancement. 
10 | """ 11 | logging.basicConfig( 12 | format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s", 13 | level=logging.INFO, 14 | ) 15 | -------------------------------------------------------------------------------- /recipes/libricss/conf/gpu.conf: -------------------------------------------------------------------------------- 1 | # Default configuration 2 | command qsub -v PATH -cwd -S /bin/bash -j y -l arch=*64* 3 | option mem=* -l mem_free=$0,ram_free=$0 4 | option mem=0 # Do not add anything to qsub_opts 5 | option num_threads=* -pe smp $0 6 | option num_threads=1 # Do not add anything to qsub_opts 7 | option max_jobs_run=* -tc $0 8 | default gpu=0 9 | option gpu=0 10 | option gpu=* -l 'hostname=c*&!c21*,gpu=$0,num_proc=1,mem_free=12G,h_rt=600:00:00' -q g.q 11 | -------------------------------------------------------------------------------- /gss/core/wpe.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from gss.wpe import wpe 4 | 5 | 6 | @dataclass 7 | class WPE: 8 | taps: int 9 | delay: int 10 | iterations: int 11 | psd_context: int 12 | 13 | def __call__(self, Obs): 14 | Obs = wpe( 15 | Obs.transpose(2, 0, 1), 16 | taps=self.taps, 17 | delay=self.delay, 18 | iterations=self.iterations, 19 | psd_context=self.psd_context, 20 | ).transpose(1, 2, 0) 21 | 22 | return Obs 23 | -------------------------------------------------------------------------------- /gss/core/beamformer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from gss.beamformer import beamform_mvdr 4 | 5 | 6 | @dataclass 7 | class Beamformer: 8 | postfilter: str 9 | 10 | def __call__(self, Obs, target_mask, distortion_mask): 11 | X_hat = beamform_mvdr( 12 | Y=Obs, X_mask=target_mask, N_mask=distortion_mask, ban=True 13 | ) 14 | 15 | if self.postfilter is None: 16 | pass 17 | elif self.postfilter == "mask_mul": 18 | X_hat = X_hat * target_mask 19 | else: 20 | raise NotImplementedError(self.postfilter) 21 | 22 | return X_hat 23 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.2.0 4 | hooks: 5 | - id: check-executables-have-shebangs 6 | - id: end-of-file-fixer 7 | - id: mixed-line-ending 8 | - id: trailing-whitespace 9 | 10 | - repo: https://github.com/PyCQA/flake8 11 | rev: 4.0.1 12 | hooks: 13 | - id: flake8 14 | args: ['--select=E9,F63,F7,F82'] 15 | 16 | - repo: https://github.com/pycqa/isort 17 | rev: 5.12.0 18 | hooks: 19 | - id: isort 20 | args: [--profile=black] 21 | 22 | - repo: https://github.com/psf/black 23 | rev: 22.3.0 24 | hooks: 25 | - id: black 26 | additional_dependencies: ['click==8.0.1'] 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Department of Communications Engineering University of Paderborn 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the 
Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /gss/core/activity.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import numpy as np 4 | from lhotse import CutSet 5 | 6 | 7 | @dataclass # (hash=True) 8 | class Activity: 9 | garbage_class: bool = False 10 | cuts: "CutSet" = None 11 | 12 | def __post_init__(self): 13 | self.activity = {} 14 | self.speaker_to_idx_map = {} 15 | for cut in self.cuts: 16 | self.speaker_to_idx_map[cut.recording_id] = { 17 | spk: idx 18 | for idx, spk in enumerate( 19 | sorted(set(s.speaker for s in cut.supervisions)) 20 | ) 21 | } 22 | self.supervisions_index = self.cuts.index_supervisions() 23 | 24 | def get_activity(self, session_id, start_time, duration): 25 | cut = self.cuts[session_id].truncate( 26 | offset=start_time, 27 | duration=duration, 28 | _supervisions_index=self.supervisions_index, 29 | ) 30 | activity_mask = cut.speakers_audio_mask( 31 | speaker_to_idx_map=self.speaker_to_idx_map[session_id] 32 | ) 33 | if self.garbage_class is False: 34 | activity_mask = np.r_[activity_mask, [np.zeros_like(activity_mask[0])]] 35 | elif self.garbage_class is True: 36 | activity_mask = np.r_[activity_mask, [np.ones_like(activity_mask[0])]] 37 | return activity_mask, self.speaker_to_idx_map[session_id] 38 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # To use a consistent encoding 2 | from codecs import open 3 | from os import path 4 | 5 | import numpy 6 | from setuptools import find_packages, setup 7 | 8 | here = path.abspath(path.dirname(__file__)) 9 | 10 | # First check if cupy is installed 11 | try: 12 | import cupy 13 | except ImportError: 14 | raise RuntimeError( 15 | "CuPy is not available. 
Please install it manually: " 16 | "https://docs.cupy.dev/en/stable/install.html" 17 | ) 18 | 19 | # Get the long description from the README file 20 | with open(path.join(here, "README.md"), encoding="utf-8") as f: 21 | long_description = f.read() 22 | 23 | dev_requires = [ 24 | "flake8==5.0.4", 25 | "black==22.3.0", 26 | "isort==5.10.1", 27 | "pre-commit>=2.17.0,<=2.19.0", 28 | ] 29 | 30 | setup( 31 | name="gss", 32 | version="0.6.1", 33 | description="GPU-accelerated Guided Source Separation", 34 | long_description=long_description, 35 | long_description_content_type="text/markdown", 36 | url="https://github.com/desh2608/gss", 37 | author="Desh Raj", 38 | author_email="r.desh26@gmail.com", 39 | keywords="speech enhancement gss", # Optional 40 | packages=find_packages(exclude=["contrib", "docs", "tests"]), # Required 41 | install_requires=[ 42 | "cached_property", 43 | "numpy", 44 | "lhotse", 45 | ], 46 | extras_require={ 47 | "dev": dev_requires, 48 | }, 49 | include_dirs=[numpy.get_include()], 50 | entry_points={ 51 | "console_scripts": [ 52 | "gss=gss.bin.gss:cli", 53 | ] 54 | }, 55 | ) 56 | -------------------------------------------------------------------------------- /gss/core/gss.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import cupy as cp 4 | 5 | from gss.cacgmm import CACGMMTrainer 6 | 7 | 8 | @dataclass 9 | class GSS: 10 | iterations: int 11 | iterations_post: int 12 | 13 | def __call__(self, Obs, activity_freq): 14 | D, T, F = Obs.shape 15 | initialization = cp.asarray(activity_freq, dtype=cp.float64) 16 | initialization = cp.where(initialization == 0, 1e-10, initialization) 17 | initialization = initialization / cp.sum(initialization, keepdims=True, axis=0) 18 | initialization = cp.repeat(initialization[None, ...], F, axis=0) 19 | 20 | source_active_mask = cp.asarray(activity_freq, dtype=cp.bool_) 21 | source_active_mask = cp.repeat(source_active_mask[None, ...], F, axis=0) 22 | 23 | cacGMM = CACGMMTrainer() 24 | 25 | cur = cacGMM.fit( 26 | y=Obs.T, 27 | initialization=initialization[..., :T], 28 | iterations=self.iterations, 29 | source_activity_mask=source_active_mask[..., :T], 30 | ) 31 | 32 | if self.iterations_post != 0: 33 | if self.iterations_post != 1: 34 | cur = cacGMM.fit( 35 | y=Obs.T, 36 | initialization=cur, 37 | iterations=self.iterations_post - 1, 38 | ) 39 | affiliation = cur.predict(Obs.T) 40 | else: 41 | affiliation = cur.predict( 42 | Obs.T, source_activity_mask=source_active_mask[..., :T] 43 | ) 44 | 45 | posterior = affiliation.transpose(1, 2, 0) 46 | 47 | return posterior 48 | -------------------------------------------------------------------------------- /recipes/ami/run_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This script is used to run the enhancement. 3 | set -euo pipefail 4 | nj=8 5 | stage=0 6 | 7 | . ./path.sh 8 | . 
parse_options.sh 9 | 10 | CORPUS_DIR=/export/corpora5/amicorpus 11 | DATA_DIR=data/ 12 | EXP_DIR=exp/ami_train 13 | 14 | cmd="queue-ackgpu.pl --gpu 1 --mem 4G --config conf/gpu.conf" 15 | 16 | if [ $stage -le 0 ]; then 17 | echo "Stage 0: Prepare manifests" 18 | lhotse prepare ami --mic mdm --partition full-corpus-asr $CORPUS_DIR $DATA_DIR 19 | fi 20 | 21 | if [ $stage -le 1 ]; then 22 | echo "Stage 1: Prepare cut set" 23 | lhotse cut simple \ 24 | -r $DATA_DIR/ami-mdm_recordings_train.jsonl.gz \ 25 | -s $DATA_DIR/ami-mdm_supervisions_train.jsonl.gz \ 26 | $EXP_DIR/cuts.jsonl.gz 27 | fi 28 | 29 | if [ $stage -le 2 ]; then 30 | echo "Stage 2: Trim cuts to supervisions (1 cut per supervision segment)" 31 | lhotse cut trim-to-supervisions --discard-overlapping \ 32 | $EXP_DIR/cuts.jsonl.gz $EXP_DIR/cuts_per_segment.jsonl.gz 33 | fi 34 | 35 | if [ $stage -le 3 ]; then 36 | echo "Stage 3: Split segments into $nj parts" 37 | gss utils split $nj $EXP_DIR/cuts_per_segment.jsonl.gz $EXP_DIR/split$nj 38 | fi 39 | 40 | if [ $stage -le 4 ]; then 41 | echo "Stage 4: Enhance segments using GSS" 42 | $cmd JOB=1:$nj $EXP_DIR/log/enhance.JOB.log \ 43 | gss enhance cuts \ 44 | $EXP_DIR/cuts.jsonl.gz $EXP_DIR/split$nj/cuts_per_segment.JOB.jsonl.gz \ 45 | $EXP_DIR/enhanced \ 46 | --num-channels 7 \ 47 | --bss-iterations 10 \ 48 | --min-segment-length 0.0 \ 49 | --max-segment-length 15.0 50 | fi 51 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # PyCharm 2 | .idea/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # pyenv 79 | .python-version 80 | 81 | # celery beat schedule file 82 | celerybeat-schedule 83 | 84 | # SageMath parsed files 85 | *.sage.py 86 | 87 | # Environments 88 | .env 89 | .venv 90 | env/ 91 | venv/ 92 | ENV/ 93 | env.bak/ 94 | venv.bak/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # mypy 107 | .mypy_cache/ 108 | 109 | # Experiment 110 | dask-worker-space/ 111 | exp/ 112 | data/ 113 | egs/*/data 114 | egs/*/exp 115 | -------------------------------------------------------------------------------- /.github/workflows/style_check.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) 2 | # 2022 Desh Raj (r.desh26@gmail.com) 3 | 4 | # See ../../LICENSE for clarification regarding multiple authors 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | name: style_check 19 | 20 | on: 21 | push: 22 | branches: 23 | - master 24 | pull_request: 25 | branches: 26 | - master 27 | 28 | jobs: 29 | style_check: 30 | runs-on: ${{ matrix.os }} 31 | strategy: 32 | matrix: 33 | os: [ubuntu-latest] 34 | python-version: [3.8] 35 | fail-fast: false 36 | 37 | steps: 38 | - uses: actions/checkout@v2 39 | with: 40 | fetch-depth: 0 41 | 42 | - name: Setup Python ${{ matrix.python-version }} 43 | uses: actions/setup-python@v1 44 | with: 45 | python-version: ${{ matrix.python-version }} 46 | 47 | - name: Install Python dependencies 48 | run: | 49 | python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0 50 | # Click issue fixed in https://github.com/psf/black/pull/2966 51 | - name: Run flake8 52 | shell: bash 53 | working-directory: ${{github.workspace}} 54 | run: | 55 | # stop the build if there are Python syntax errors or undefined names 56 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 57 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 58 | flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 \ 59 | --statistics --extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503 60 | - name: Run black 61 | shell: bash 62 | working-directory: ${{github.workspace}} 63 | run: | 64 | black --check --diff . 65 | -------------------------------------------------------------------------------- /recipes/dipco/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This script is used to run the enhancement. 3 | set -euo pipefail 4 | nj=4 5 | affix="" 6 | stage=0 7 | stop_stage=100 8 | 9 | . ./path.sh 10 | . parse_options.sh 11 | 12 | # Append _ to affix if not empty 13 | affix=${affix:+_$affix} 14 | 15 | CORPUS_DIR=/export/corpora6/DiPCo/DiPCo 16 | DATA_DIR=data/ 17 | EXP_DIR=exp/dipco${affix} 18 | 19 | cmd="queue-ackgpu.pl --gpu 1 --mem 8G --config conf/gpu.conf" 20 | 21 | mkdir -p $DATA_DIR 22 | mkdir -p $EXP_DIR/{dev,eval} 23 | 24 | if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then 25 | echo "Stage 0: Prepare manifests" 26 | lhotse prepare dipco --mic mdm $CORPUS_DIR $DATA_DIR 27 | fi 28 | 29 | if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then 30 | echo "Stage 1: Prepare cut set" 31 | for part in dev eval; do 32 | lhotse cut simple --force-eager \ 33 | -r $DATA_DIR/dipco-mdm_recordings_${part}.jsonl.gz \ 34 | -s $DATA_DIR/dipco-mdm_supervisions_${part}.jsonl.gz \ 35 | $EXP_DIR/$part/cuts.jsonl.gz 36 | done 37 | fi 38 | 39 | if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then 40 | echo "Stage 2: Trim cuts to supervisions (1 cut per supervision segment)" 41 | for part in dev eval; do 42 | lhotse cut trim-to-supervisions --discard-overlapping \ 43 | $EXP_DIR/$part/cuts.jsonl.gz $EXP_DIR/$part/cuts_per_segment.jsonl.gz 44 | done 45 | fi 46 | 47 | if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then 48 | echo "Stage 3: Split segments into $nj parts" 49 | for part in dev eval; do 50 | gss utils split $nj $EXP_DIR/$part/cuts_per_segment.jsonl.gz $EXP_DIR/$part/split$nj 51 | done 52 | fi 53 | 54 | if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then 55 | echo "Stage 4: Enhance segments using GSS (central array mics)" 56 | for part in dev eval; do 57 | $cmd JOB=1:$nj $EXP_DIR/$part/log/enhance.JOB.log \ 58 | gss enhance cuts \ 59 | $EXP_DIR/$part/cuts.jsonl.gz $EXP_DIR/$part/split$nj/cuts_per_segment.JOB.jsonl.gz \ 60 | $EXP_DIR/$part/enhanced \ 61 | --channels 0,7,14,21,28 \ 62 | --bss-iterations 20 \ 63 | --context-duration 15.0 \ 64 | --use-garbage-class \ 65 | --min-segment-length 0.0 \ 66 | --max-segment-length 20.0 \ 67 | --max-batch-duration 20.0 \ 68 | --num-buckets 4 \ 69 | --num-workers 4 \ 70 | --force-overwrite \ 71 | --duration-tolerance 3.0 || exit 1 72 | done 73 | fi 74 | -------------------------------------------------------------------------------- /recipes/libricss/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This script is used to run the enhancement. 3 | set -euo pipefail 4 | nj=8 5 | rttm_dir="" 6 | affix="oracle" 7 | stage=0 8 | stop_stage=100 9 | 10 | . ./path.sh 11 | . 
parse_options.sh 12 | 13 | CORPUS_DIR=/export/fs01/LibriCSS 14 | DATA_DIR=data/ 15 | EXP_DIR=exp/libricss_${affix} 16 | 17 | cmd="queue-ackgpu.pl --gpu 1 --mem 4G --config conf/gpu.conf" 18 | 19 | mkdir -p $DATA_DIR 20 | mkdir -p $EXP_DIR 21 | 22 | if [ -z $rttm_dir ]; then 23 | supervisions_path=$DATA_DIR/libricss_supervisions_all.jsonl.gz 24 | else 25 | supervisions_path=$EXP_DIR/supervisions_${affix}.jsonl.gz 26 | fi 27 | 28 | if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then 29 | echo "Stage 0: Prepare manifests" 30 | lhotse prepare libricss --type mdm $CORPUS_DIR $DATA_DIR 31 | fi 32 | 33 | if [ $stage -le 1 ] && [ $stop_stage -ge 1 ] && [ ! -z $rttm_dir ]; then 34 | echo "Stage 1: Create supervisions from RTTM file" 35 | gss utils rttm-to-supervisions --channels 7 $rttm_dir $supervisions_path 36 | fi 37 | 38 | if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then 39 | echo "Stage 2: Prepare cut set" 40 | # --force-eager must be set if recordings are not sorted by id 41 | lhotse cut simple --force-eager \ 42 | -r $DATA_DIR/libricss_recordings_all.jsonl.gz \ 43 | -s $supervisions_path \ 44 | $EXP_DIR/cuts.jsonl.gz 45 | fi 46 | 47 | if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then 48 | echo "Stage 3: Trim cuts to supervisions (1 cut per supervision segment)" 49 | lhotse cut trim-to-supervisions --discard-overlapping \ 50 | $EXP_DIR/cuts.jsonl.gz $EXP_DIR/cuts_per_segment.jsonl.gz 51 | fi 52 | 53 | if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then 54 | echo "Stage 4: Split segments into $nj parts" 55 | gss utils split $nj $EXP_DIR/cuts_per_segment.jsonl.gz $EXP_DIR/split$nj 56 | fi 57 | 58 | if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then 59 | echo "Stage 5: Enhance segments using GSS" 60 | $cmd JOB=1:$nj $EXP_DIR/log/enhance.JOB.log \ 61 | gss enhance cuts \ 62 | $EXP_DIR/cuts.jsonl.gz $EXP_DIR/split$nj/cuts_per_segment.JOB.jsonl.gz \ 63 | $EXP_DIR/enhanced \ 64 | --use-garbage-class \ 65 | --channels 0,1,2,3,4,5,6 \ 66 | --bss-iterations 10 \ 67 | --context-duration 15.0 \ 68 | --min-segment-length 0.1 \ 69 | --max-segment-length 15.0 \ 70 | --max-batch-duration 20.0 \ 71 | --num-buckets 3 \ 72 | --force-overwrite 73 | fi 74 | -------------------------------------------------------------------------------- /recipes/chime6/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This script is used to run the enhancement. 3 | set -euo pipefail 4 | nj=4 5 | affix="" 6 | stage=0 7 | stop_stage=100 8 | 9 | . ./path.sh 10 | . 
parse_options.sh 11 | 12 | # Append _ to affix if not empty 13 | affix=${affix:+_$affix} 14 | 15 | CORPUS_DIR=/expscratch/mwiesner/CHiME6 16 | DATA_DIR=data/ 17 | EXP_DIR=exp/chime6${affix} 18 | 19 | cmd="queue-freegpu.pl --gpu 1 --mem 8G --config conf/gpu.conf" 20 | 21 | mkdir -p $DATA_DIR 22 | mkdir -p $EXP_DIR/{dev,eval} 23 | 24 | if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then 25 | echo "Stage 0: Prepare manifests" 26 | lhotse prepare chime6 --mic mdm -p dev -p eval $CORPUS_DIR $DATA_DIR 27 | fi 28 | 29 | if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then 30 | echo "Stage 1: Prepare cut set" 31 | for part in dev eval; do 32 | lhotse cut simple --force-eager \ 33 | -r $DATA_DIR/chime6-mdm_recordings_${part}.jsonl.gz \ 34 | -s $DATA_DIR/chime6-mdm_supervisions_${part}.jsonl.gz \ 35 | $EXP_DIR/$part/cuts.jsonl.gz 36 | done 37 | fi 38 | 39 | if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then 40 | echo "Stage 2: Trim cuts to supervisions (1 cut per supervision segment)" 41 | for part in dev eval; do 42 | lhotse cut trim-to-supervisions --discard-overlapping \ 43 | $EXP_DIR/$part/cuts.jsonl.gz $EXP_DIR/$part/cuts_per_segment.jsonl.gz 44 | done 45 | fi 46 | 47 | if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then 48 | echo "Stage 3: Split segments into $nj parts" 49 | for part in dev eval; do 50 | gss utils split $nj $EXP_DIR/$part/cuts_per_segment.jsonl.gz $EXP_DIR/$part/split$nj 51 | done 52 | fi 53 | 54 | if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then 55 | echo "Stage 4: Enhance segments using GSS (outer array mics)" 56 | # NOTE: U03 is missing is S01 and U05 is missing in S09, so we only use 57 | # 10 channels here instead of 12. 58 | for part in dev; do 59 | $cmd JOB=1:$nj $EXP_DIR/$part/log/enhance.JOB.log \ 60 | gss enhance cuts \ 61 | $EXP_DIR/$part/cuts.jsonl.gz $EXP_DIR/$part/split$nj/cuts_per_segment.JOB.jsonl.gz \ 62 | $EXP_DIR/$part/enhanced \ 63 | --bss-iterations 20 \ 64 | --context-duration 15.0 \ 65 | --use-garbage-class \ 66 | --min-segment-length 0.0 \ 67 | --max-segment-length 20.0 \ 68 | --max-batch-duration 30.0 \ 69 | --max-batch-cuts 1 \ 70 | --num-buckets 4 \ 71 | --num-workers 4 \ 72 | --force-overwrite \ 73 | --duration-tolerance 3.0 || exit 1 74 | done 75 | fi 76 | -------------------------------------------------------------------------------- /recipes/alimeeting/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This script is used to run the enhancement. 3 | set -euo pipefail 4 | nj=8 5 | rttm_dir="" 6 | affix="_oracle" 7 | stage=0 8 | stop_stage=100 9 | 10 | . ./path.sh 11 | . parse_options.sh 12 | 13 | CORPUS_DIR=/export/c01/corpora6/AliMeeting 14 | DATA_DIR=data/ 15 | EXP_DIR=exp/alimeeting${affix} 16 | 17 | cmd="queue-ackgpu.pl --gpu 1 --mem 4G --config conf/gpu.conf" 18 | 19 | mkdir -p $DATA_DIR 20 | mkdir -p $EXP_DIR 21 | 22 | if [ -z $rttm_dir ]; then 23 | supervisions_path=$DATA_DIR/alimeeting_supervisions_all.jsonl.gz 24 | else 25 | supervisions_path=$EXP_DIR/supervisions_${rttm_tag}.jsonl.gz 26 | fi 27 | 28 | if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then 29 | echo "Stage 0: Prepare manifests" 30 | lhotse prepare ali-meeting --mic far $CORPUS_DIR $DATA_DIR 31 | lhotse combine $DATA_DIR/alimeeting_recordings_{eval,test}.jsonl.gz data/alimeeting_recordings_all.jsonl.gz 32 | lhotse combine $DATA_DIR/alimeeting_supervisions_{eval,test}.jsonl.gz data/alimeeting_supervisions_all.jsonl.gz 33 | fi 34 | 35 | if [ $stage -le 1 ] && [ $stop_stage -ge 1 ] && [ ! 
-z $rttm_dir ]; then 36 | echo "Stage 1: Create supervisions from RTTM file" 37 | gss utils rttm-to-supervisions --channels 7 $rttm_dir $supervisions_path 38 | fi 39 | 40 | if [ $stage -le 2 ] && [ $stop_stage -ge 1 ]; then 41 | echo "Stage 2: Prepare cut set" 42 | lhotse cut simple --force-eager \ 43 | -r $DATA_DIR/alimeeting_recordings_all.jsonl.gz \ 44 | -s $supervisions_path \ 45 | $EXP_DIR/cuts.jsonl.gz 46 | fi 47 | 48 | if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then 49 | echo "Stage 3: Trim cuts to supervisions (1 cut per supervision segment)" 50 | lhotse cut trim-to-supervisions --discard-overlapping \ 51 | $EXP_DIR/cuts.jsonl.gz $EXP_DIR/cuts_per_segment.jsonl.gz 52 | fi 53 | 54 | if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then 55 | echo "Stage 4: Split segments into $nj parts" 56 | gss utils split $nj $EXP_DIR/cuts_per_segment.jsonl.gz $EXP_DIR/split$nj 57 | fi 58 | 59 | if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then 60 | echo "Stage 5: Enhance segments using GSS" 61 | $cmd JOB=1:$nj $EXP_DIR/log/enhance.JOB.log \ 62 | gss enhance cuts \ 63 | $EXP_DIR/cuts.jsonl.gz $EXP_DIR/split$nj/cuts_per_segment.JOB.jsonl.gz \ 64 | $EXP_DIR/enhanced \ 65 | --bss-iterations 10 \ 66 | --min-segment-length 0.0 \ 67 | --max-segment-length 15.0 \ 68 | --max-batch-duration 20.0 \ 69 | --num-buckets 3 \ 70 | --enhanced-manifest $EXP_DIR/split$nj/cuts_enhanced.JOB.jsonl.gz 71 | 72 | echo "Stage 5: Combine enhanced cuts" 73 | lhotse combine $EXP_DIR/split$nj/cuts_enhanced.*.jsonl.gz $EXP_DIR/cuts_enhanced.jsonl.gz 74 | fi 75 | -------------------------------------------------------------------------------- /recipes/ami/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This script is used to run the enhancement. 3 | set -euo pipefail 4 | nj=8 5 | rttm_dir="" 6 | affix="_oracle" 7 | stage=0 8 | stop_stage=100 9 | 10 | . ./path.sh 11 | . parse_options.sh 12 | 13 | CORPUS_DIR=/export/corpora5/amicorpus 14 | DATA_DIR=data/ 15 | EXP_DIR=exp/ami${affix} 16 | 17 | cmd="queue-ackgpu.pl --gpu 1 --mem 4G --config conf/gpu.conf" 18 | 19 | mkdir -p $DATA_DIR 20 | mkdir -p $EXP_DIR 21 | 22 | if [ -z $rttm_dir ]; then 23 | supervisions_path=$DATA_DIR/ami_supervisions_all.jsonl.gz 24 | else 25 | supervisions_path=$EXP_DIR/supervisions_${rttm_tag}.jsonl.gz 26 | fi 27 | 28 | if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then 29 | echo "Stage 0: Prepare manifests" 30 | lhotse prepare ami --mic mdm --partition full-corpus-asr $CORPUS_DIR $DATA_DIR 31 | lhotse combine $DATA_DIR/ami-mdm_recordings_{dev,test}.jsonl.gz data/ami-mdm_recordings_all.jsonl.gz 32 | lhotse combine $DATA_DIR/ami-mdm_supervisions_{dev,test}.jsonl.gz data/ami-mdm_supervisions_all.jsonl.gz 33 | fi 34 | 35 | if [ $stage -le 1 ] && [ $stop_stage -ge 1 ] && [ ! 
-z $rttm_dir ]; then 36 | echo "Stage 1: Create supervisions from RTTM file" 37 | gss utils rttm-to-supervisions --channels 7 $rttm_dir $supervisions_path 38 | fi 39 | 40 | if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then 41 | echo "Stage 2: Prepare cut set" 42 | lhotse cut simple --force-eager \ 43 | -r $DATA_DIR/ami-mdm_recordings_all.jsonl.gz \ 44 | -s $supervisions_path \ 45 | $EXP_DIR/cuts.jsonl.gz 46 | fi 47 | 48 | if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then 49 | echo "Stage 3: Trim cuts to supervisions (1 cut per supervision segment)" 50 | lhotse cut trim-to-supervisions --discard-overlapping \ 51 | $EXP_DIR/cuts.jsonl.gz $EXP_DIR/cuts_per_segment.jsonl.gz 52 | fi 53 | 54 | if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then 55 | echo "Stage 4: Split segments into $nj parts" 56 | gss utils split $nj $EXP_DIR/cuts_per_segment.jsonl.gz $EXP_DIR/split$nj 57 | fi 58 | 59 | if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then 60 | echo "Stage 5: Enhance segments using GSS" 61 | $cmd JOB=1:$nj $EXP_DIR/log/enhance.JOB.log \ 62 | gss utils gpu_check $nj $cmd \& gss enhance cuts \ 63 | $EXP_DIR/cuts.jsonl.gz $EXP_DIR/split$nj/cuts_per_segment.JOB.jsonl.gz \ 64 | $EXP_DIR/enhanced \ 65 | --channels 0,1,2,3,4,5,6,7 \ 66 | --bss-iterations 10 \ 67 | --min-segment-length 0.0 \ 68 | --max-segment-length 15.0 \ 69 | --max-batch-duration 20.0 \ 70 | --num-buckets 3 \ 71 | --enhanced-manifest $EXP_DIR/split$nj/cuts_enhanced.JOB.jsonl.gz 72 | 73 | echo "Stage 5: Combine enhanced cuts" 74 | lhotse combine $EXP_DIR/split$nj/cuts_enhanced.*.jsonl.gz $EXP_DIR/cuts_enhanced.jsonl.gz 75 | fi 76 | -------------------------------------------------------------------------------- /recipes/libricss/utils/acquire-gpu: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Author: Guanghui Qin 3 | # Report bugs to Guanghui Qin (via slack or gqin@jhu.edu) 4 | 5 | # It automatically detect the idle CUDA devices. If available, it locks one of them and append 6 | # the device number to the CUDA_VISIBLE_DEVICES variable. 7 | # A device is considered available if (1) no job is running on this device; (2) not locked. 8 | # Usage: (Suppose you need only one CUDA) 9 | # >>> source acquire-gpu 10 | # >>> # Your codes start here. 11 | # If you need more than 1 CUDA, simply use a for loop or repeat the script n times. 12 | # Suppose you need 3 CUDA devices: 13 | # >>> for _ in $(seq 3); do source acquire-gpu; done 14 | # Be sure to use `source` instead of `bash` 15 | 16 | # The mechanism of acquire-gpu is to associate every GPU with a file. A job locks the file if 17 | # it needs to use that device, and release the lock when it exits, either normally or abnormally. 18 | 19 | # Copy this script to your local and remove this line if you don't want to be tracked. 20 | # [ -w /home/gqin2/.track/acquire-gpu.track ] && echo "$USER $(hostname) $(date +'%Y-%m-%d %H:%M:%S')" >> /home/gqin2/.track/acquire-gpu.track 21 | 22 | # Create the lock folder if it doesn't exist. 23 | LOCK_DIR="$HOME/.lock" 24 | mkdir -p "$LOCK_DIR" 25 | if [ ! -f "$LOCK_DIR/master.lock" ]; then 26 | touch "$LOCK_DIR/master.lock" 27 | fi 28 | # master lock is used to prevent racing between different programs running this script. 
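# A note on the locking idiom used below: `exec {VAR}>file` asks bash to open
# `file` on a freshly allocated file descriptor and store the descriptor number
# in VAR, and `flock -x $VAR` then takes an exclusive lock on that descriptor.
# The kernel drops the lock automatically when the descriptor is closed
# (`exec {VAR}>&-`) or when the process exits, so a crashed job cannot leave a
# GPU lock file permanently held.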
29 | exec {MASTER_FN}>"$LOCK_DIR/master.lock" 30 | flock -x $MASTER_FN 31 | 32 | N_GPU=$(nvidia-smi -L | wc -l) 33 | echo "Number of GPUs: $N_GPU" 34 | 35 | # Parse nvidia-smi to get available devices 36 | FREE_GPU=$(nvidia-smi | sed -e '1,/Processes/d' | tail -n+3 | head -n-1 | awk '{print $2}'\ 37 | | awk -v ng="$N_GPU" 'BEGIN{for (n=0;n"$TMP_LOCK" 49 | if ! flock -xn $TMP_FN ; then 50 | # echo "BUSY: $DEVICE_ID" 51 | FREE_GPU=$(sed "/$DEVICE_ID/d" <<< "$FREE_GPU") 52 | fi 53 | exec {TMP_FN}>&- 54 | done 55 | 56 | echo "Free GPUs: " $FREE_GPU 57 | 58 | # Passing parameter to this script to indicate the number of GPUs is dangerous, which 59 | # might collide with the outside environment variable since we're sourcing this script. 60 | # I deprecated the method, and acquire exactly 1 device each time. 61 | SELECTED_DEVICES=$(head -n 1 <<< "$FREE_GPU") 62 | 63 | # Lock the device 64 | for DEVICE in $SELECTED_DEVICES; do 65 | echo "Select device: $DEVICE" 66 | LOCK="$LOCK_DIR/$(cat /etc/hostname).$DEVICE.lock" 67 | touch "$LOCK" 68 | exec {CUR_FN}> $LOCK 69 | flock -x $CUR_FN 70 | CUDA_VISIBLE_DEVICES+="$DEVICE," 71 | done 72 | 73 | export CUDA_VISIBLE_DEVICES 74 | 75 | exec {MASTER_FN}>&- 76 | -------------------------------------------------------------------------------- /gss/beamformer/souden_mvdr.py: -------------------------------------------------------------------------------- 1 | import cupy as cp 2 | from cached_property import cached_property 3 | 4 | from gss.beamformer.beamform import ( 5 | apply_beamforming_vector, 6 | blind_analytic_normalization, 7 | get_mvdr_vector_souden, 8 | get_power_spectral_density_matrix, 9 | ) 10 | from gss.beamformer.utils import morph 11 | 12 | # The _Beamformer class is modified from: 13 | # https://github.com/fgnt/pb_chime5/blob/master/pb_chime5/speech_enhancement/beamforming_wrapper.py 14 | 15 | 16 | class _Beamformer: 17 | def __init__( 18 | self, 19 | Y, 20 | X_mask, 21 | N_mask, 22 | ): 23 | if cp.ndim(Y) == 4: 24 | self.Y = morph("1DTF->FDT", Y) 25 | else: 26 | self.Y = morph("DTF->FDT", Y) 27 | 28 | if cp.ndim(X_mask) == 4: 29 | self.X_mask = morph("1DTF->FT", X_mask, reduce=cp.median) 30 | self.N_mask = morph("1DTF->FT", N_mask, reduce=cp.median) 31 | elif cp.ndim(X_mask) == 3: 32 | self.X_mask = morph("DTF->FT", X_mask, reduce=cp.median) 33 | self.N_mask = morph("DTF->FT", N_mask, reduce=cp.median) 34 | elif cp.ndim(X_mask) == 2: 35 | self.X_mask = morph("TF->FT", X_mask, reduce=cp.median) 36 | self.N_mask = morph("TF->FT", N_mask, reduce=cp.median) 37 | else: 38 | raise NotImplementedError(X_mask.shape) 39 | 40 | assert self.Y.ndim == 3, self.Y.shape 41 | F, D, T = self.Y.shape 42 | assert D < 30, (D, self.Y.shape) 43 | assert self.X_mask.shape == (F, T), (self.X_mask.shape, F, T) 44 | assert self.N_mask.shape == (F, T), (self.N_mask.shape, F, T) 45 | 46 | @cached_property 47 | def _Cov_X(self): 48 | Cov_X = get_power_spectral_density_matrix(self.Y, self.X_mask) 49 | return Cov_X 50 | 51 | @cached_property 52 | def _Cov_N(self): 53 | Cov_N = get_power_spectral_density_matrix(self.Y, self.N_mask) 54 | return Cov_N 55 | 56 | @cached_property 57 | def _w_mvdr_souden(self): 58 | w_mvdr_souden = get_mvdr_vector_souden(self._Cov_X, self._Cov_N, eps=1e-10) 59 | return w_mvdr_souden 60 | 61 | @cached_property 62 | def _w_mvdr_souden_ban(self): 63 | w_mvdr_souden_ban = blind_analytic_normalization( 64 | self._w_mvdr_souden, self._Cov_N 65 | ) 66 | return w_mvdr_souden_ban 67 | 68 | @cached_property 69 | def X_hat_mvdr_souden(self): 70 | return 
apply_beamforming_vector(self._w_mvdr_souden, self.Y).T 71 | 72 | @cached_property 73 | def X_hat_mvdr_souden_ban(self): 74 | return apply_beamforming_vector(self._w_mvdr_souden_ban, self.Y).T 75 | 76 | 77 | def beamform_mvdr(Y, X_mask, N_mask, ban=False): 78 | """ 79 | Souden MVDR beamformer. 80 | Args: 81 | Y: CuPy array of shape (channel, time, frequency). 82 | X_mask: CuPy array of shape (time, frequency). 83 | N_mask: CuPy array of shape (time, frequency). 84 | ban: If True, use blind analytic normalization. 85 | Returns: 86 | X_hat: Beamformed signal, CuPy array of shape (time, frequency). 87 | """ 88 | bf = _Beamformer( 89 | Y=Y, 90 | X_mask=X_mask, 91 | N_mask=N_mask, 92 | ) 93 | if ban: 94 | return bf.X_hat_mvdr_souden_ban 95 | else: 96 | return bf.X_hat_mvdr_souden 97 | -------------------------------------------------------------------------------- /gss/cacgmm/cacgmm.py: -------------------------------------------------------------------------------- 1 | # The functions here are modified from: 2 | # https://github.com/fgnt/pb_bss/blob/master/pb_bss/distribution/cacgmm.py 3 | 4 | from dataclasses import dataclass, field 5 | 6 | import cupy as cp 7 | 8 | from gss.cacgmm.cacg import ComplexAngularCentralGaussian 9 | from gss.cacgmm.utils import log_pdf_to_affiliation, normalize_observation 10 | 11 | 12 | def logsumexp(a, axis=None): 13 | a_max = cp.amax(a, axis=axis, keepdims=True) 14 | 15 | if a_max.ndim > 0: 16 | a_max[~cp.isfinite(a_max)] = 0 17 | elif not cp.isfinite(a_max): 18 | a_max = 0 19 | 20 | tmp = cp.exp(a - a_max) 21 | 22 | # suppress warnings about log of zero 23 | with cp.errstate(divide="ignore"): 24 | s = cp.sum(tmp, axis=axis, keepdims=False) 25 | out = cp.log(s) 26 | 27 | a_max = cp.squeeze(a_max, axis=axis) 28 | out += a_max 29 | 30 | return out 31 | 32 | 33 | @dataclass 34 | class CACGMM: 35 | weight: cp.array = None # (..., K, 1) for weight_constant_axis==(-1,) (..., 1, K, T) for weight_constant_axis==(-3,) 36 | cacg: ComplexAngularCentralGaussian = field( 37 | default_factory=ComplexAngularCentralGaussian 38 | ) 39 | 40 | def predict(self, y, return_quadratic_form=False, source_activity_mask=None): 41 | assert cp.iscomplexobj(y), y.dtype 42 | y = normalize_observation(y) # swap D and T dim 43 | affiliation, quadratic_form, _ = self._predict( 44 | y, source_activity_mask=source_activity_mask 45 | ) 46 | if return_quadratic_form: 47 | return affiliation, quadratic_form 48 | else: 49 | return affiliation 50 | 51 | def _predict(self, y, source_activity_mask=None, affiliation_eps=0.0): 52 | """ 53 | Note: y shape is (..., D, T) and not (..., T, D) like in predict 54 | Args: 55 | y: Normalized observations with shape (..., D, T). 56 | Returns: Affiliations with shape (..., K, T) and quadratic format 57 | with the same shape. 
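        (When called from the GSS front-end in gss/core/gss.py, the leading
        batch dimension is frequency and K is the number of speakers plus an
        optional garbage class, so the affiliations act as per-frequency,
        per-speaker time-frequency masks.)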
58 | """ 59 | *independent, _, num_observations = y.shape 60 | 61 | log_pdf, quadratic_form = self.cacg._log_pdf(y[..., None, :, :]) 62 | 63 | affiliation = log_pdf_to_affiliation( 64 | self.weight, 65 | log_pdf, 66 | source_activity_mask=source_activity_mask, 67 | affiliation_eps=affiliation_eps, 68 | ) 69 | 70 | return affiliation, quadratic_form, log_pdf 71 | 72 | def log_likelihood(self, y): 73 | assert cp.iscomplexobj(y), y.dtype 74 | y = normalize_observation(y) # swap D and T dim 75 | affiliation, quadratic_form, log_pdf = self._predict(y) 76 | return self._log_likelihood(y, log_pdf) 77 | 78 | def _log_likelihood(self, y, log_pdf): 79 | """ 80 | Note: y shape is (..., D, T) and not (..., T, D) like in log_likelihood 81 | Args: 82 | y: Normalized observations with shape (..., D, T). 83 | log_pdf: shape (..., K, T) 84 | Returns: 85 | log_likelihood, scalar 86 | """ 87 | *independent, channels, num_observations = y.shape 88 | 89 | # log_pdf.shape: *independent, speakers, num_observations 90 | 91 | # first: sum above the speakers 92 | # second: sum above time frequency in log domain 93 | log_likelihood = cp.sum(logsumexp(log_pdf, axis=-2)) 94 | return log_likelihood 95 | -------------------------------------------------------------------------------- /gss/cacgmm/utils.py: -------------------------------------------------------------------------------- 1 | import cupy as cp 2 | 3 | 4 | def _unit_norm(signal, *, axis=-1, eps=1e-4, eps_style="plus", ord=None): 5 | norm = cp.linalg.norm(signal, ord=ord, axis=axis, keepdims=True) 6 | if eps_style == "plus": 7 | norm = norm + eps 8 | elif eps_style == "max": 9 | norm = cp.maximum(norm, eps) 10 | elif eps_style == "where": 11 | norm = cp.where(norm == 0, eps, norm) 12 | else: 13 | assert False, eps_style 14 | return signal / norm 15 | 16 | 17 | def force_hermitian(matrix): 18 | return (matrix + cp.swapaxes(matrix.conj(), -1, -2)) / 2 19 | 20 | 21 | def estimate_mixture_weight( 22 | affiliation, 23 | saliency=None, 24 | weight_constant_axis=-1, 25 | ): 26 | affiliation = cp.asarray(affiliation) 27 | 28 | if ( 29 | isinstance(weight_constant_axis, int) 30 | and weight_constant_axis % affiliation.ndim - affiliation.ndim == -2 31 | ): 32 | K = affiliation.shape[-2] 33 | return cp.full([K, 1], 1 / K) 34 | elif isinstance(weight_constant_axis, list): 35 | weight_constant_axis = tuple(weight_constant_axis) 36 | 37 | if saliency is None: 38 | weight = cp.mean(affiliation, axis=weight_constant_axis, keepdims=True) 39 | else: 40 | masked_affiliation = affiliation * saliency[..., None, :] 41 | weight = _unit_norm( 42 | cp.sum(masked_affiliation, axis=weight_constant_axis, keepdims=True), 43 | ord=1, 44 | axis=-2, 45 | eps=1e-10, 46 | eps_style="where", 47 | ) 48 | 49 | return weight 50 | 51 | 52 | def log_pdf_to_affiliation( 53 | weight, 54 | log_pdf, 55 | source_activity_mask=None, 56 | affiliation_eps=0.0, 57 | ): 58 | # Only check broadcast compatibility 59 | if source_activity_mask is None: 60 | _ = cp.broadcast_arrays(weight, log_pdf) 61 | else: 62 | _ = cp.broadcast_arrays(weight, log_pdf, source_activity_mask) 63 | 64 | # The value of affiliation max may exceed float64 range. 65 | # Scaling (add in log domain) does not change the final affiliation. 66 | affiliation = log_pdf - cp.amax(log_pdf, axis=-2, keepdims=True) 67 | 68 | cp.exp(affiliation, out=affiliation) 69 | 70 | # Weight multiplied not in log domain to avoid logarithm of zero. 
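    # At this point `affiliation` holds exp(log_pdf - max_k log_pdf); multiplying
    # by the mixture weights and normalizing over the source axis below gives the
    # posterior p(k | y_t) proportional to weight_k * pdf_k(y_t), i.e. a
    # numerically stable weighted softmax over the K sources.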
71 | affiliation *= weight 72 | 73 | if source_activity_mask is not None: 74 | assert ( 75 | source_activity_mask.dtype == cp.bool_ 76 | ), source_activity_mask.dtype # noqa 77 | affiliation *= source_activity_mask 78 | 79 | denominator = cp.maximum( 80 | cp.sum(affiliation, axis=-2, keepdims=True), 81 | cp.finfo(affiliation.dtype).tiny, 82 | ) 83 | affiliation /= denominator 84 | 85 | # Strictly, you need re-normalization after clipping. We skip that here. 86 | if affiliation_eps != 0: 87 | affiliation = cp.clip( 88 | affiliation, 89 | affiliation_eps, 90 | 1 - affiliation_eps, 91 | ) 92 | 93 | return affiliation 94 | 95 | 96 | def is_broadcast_compatible(*shapes): 97 | if len(shapes) < 2: 98 | return True 99 | else: 100 | for dim in zip(*[shape[::-1] for shape in shapes]): 101 | if len(set(dim).union({1})) <= 2: 102 | pass 103 | else: 104 | return False 105 | return True 106 | 107 | 108 | def normalize_observation(observation): 109 | """ 110 | Attention: swap D and T dim 111 | """ 112 | observation = _unit_norm( 113 | observation, 114 | axis=-1, 115 | eps=cp.finfo(observation.dtype).tiny, 116 | eps_style="where", 117 | ) 118 | return cp.ascontiguousarray(cp.swapaxes(observation, -2, -1)) 119 | -------------------------------------------------------------------------------- /recipes/libricss/utils/parse_options.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright 2012 Johns Hopkins University (Author: Daniel Povey); 4 | # Arnab Ghoshal, Karel Vesely 5 | 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 14 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 15 | # MERCHANTABLITY OR NON-INFRINGEMENT. 16 | # See the Apache 2 License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | 20 | # Parse command-line options. 21 | # To be sourced by another script (as in ". parse_options.sh"). 22 | # Option format is: --option-name arg 23 | # and shell variable "option_name" gets set to value "arg." 24 | # The exception is --help, which takes no arguments, but prints the 25 | # $help_message variable (if defined). 26 | 27 | 28 | ### 29 | ### The --config file options have lower priority to command line 30 | ### options, so we need to import them first... 31 | ### 32 | 33 | # Now import all the configs specified by command-line, in left-to-right order 34 | for ((argpos=1; argpos<$#; argpos++)); do 35 | if [ "${!argpos}" == "--config" ]; then 36 | argpos_plus1=$((argpos+1)) 37 | config=${!argpos_plus1} 38 | [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 39 | . $config # source the config file. 40 | fi 41 | done 42 | 43 | 44 | ### 45 | ### Now we process the command line options 46 | ### 47 | while true; do 48 | [ -z "${1:-}" ] && break; # break if there are no arguments 49 | case "$1" in 50 | # If the enclosing script is called with --help option, print the help 51 | # message and exit. Scripts should put help messages in $help_message 52 | --help|-h) if [ -z "$help_message" ]; then echo "No help found." 
1>&2; 53 | else printf "$help_message\n" 1>&2 ; fi; 54 | exit 0 ;; 55 | --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" 56 | exit 1 ;; 57 | # If the first command-line argument begins with "--" (e.g. --foo-bar), 58 | # then work out the variable name as $name, which will equal "foo_bar". 59 | --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; 60 | # Next we test whether the variable in question is undefned-- if so it's 61 | # an invalid option and we die. Note: $0 evaluates to the name of the 62 | # enclosing script. 63 | # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar 64 | # is undefined. We then have to wrap this test inside "eval" because 65 | # foo_bar is itself inside a variable ($name). 66 | eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; 67 | 68 | oldval="`eval echo \\$$name`"; 69 | # Work out whether we seem to be expecting a Boolean argument. 70 | if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then 71 | was_bool=true; 72 | else 73 | was_bool=false; 74 | fi 75 | 76 | # Set the variable to the right value-- the escaped quotes make it work if 77 | # the option had spaces, like --cmd "queue.pl -sync y" 78 | eval $name=\"$2\"; 79 | 80 | # Check that Boolean-valued arguments are really Boolean. 81 | if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then 82 | echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 83 | exit 1; 84 | fi 85 | shift 2; 86 | ;; 87 | *) break; 88 | esac 89 | done 90 | 91 | 92 | # Check for an empty argument to the --cmd option, which can easily occur as a 93 | # result of scripting errors. 94 | [ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; 95 | 96 | 97 | true; # so this script returns exit code 0. 98 | -------------------------------------------------------------------------------- /gss/bin/modes/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import re 4 | import subprocess 5 | from pathlib import Path 6 | 7 | import click 8 | 9 | from gss.bin.modes.cli_base import cli 10 | 11 | logging.basicConfig( 12 | format="%(asctime)s,%(msecs)d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s", 13 | datefmt="%Y-%m-%d:%H:%M:%S", 14 | level=logging.INFO, 15 | ) 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | @cli.group() 20 | def utils(): 21 | """General utilities for manipulating manifests.""" 22 | pass 23 | 24 | 25 | @utils.command(name="rttm-to-supervisions") 26 | @click.argument("rttm_path", type=click.Path(exists=True)) 27 | @click.argument("out_path", type=click.Path()) 28 | @click.option( 29 | "--channels", 30 | "-c", 31 | type=int, 32 | default=1, 33 | help="Number of channels in the recording (supervisions will be modified to contain all these channels).", 34 | ) 35 | def rttm_to_supervisions_(rttm_path, out_path, channels): 36 | """ 37 | Convert RTTM file to Supervisions manifest. 38 | """ 39 | from lhotse import SupervisionSet 40 | from lhotse.utils import fastcopy 41 | 42 | rttm_path = Path(rttm_path) 43 | rttm_files = rttm_path if rttm_path.is_file() else rttm_path.rglob("*.rttm") 44 | supervisions = SupervisionSet.from_rttm(rttm_files) 45 | # Supervisions obtained from RTTM files are single-channel only, so we modify the 46 | # ``channel`` field to share it for all channels. 
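    # For example, ``--channels 7`` turns each segment's channel into
    # [0, 1, 2, 3, 4, 5, 6].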
47 | supervisions = SupervisionSet.from_segments( 48 | [fastcopy(s, channel=list(range(channels))) for s in supervisions] 49 | ) 50 | supervisions.to_file(out_path) 51 | 52 | 53 | @utils.command(name="gpu_check") 54 | @click.argument("num_jobs", type=int) 55 | @click.argument("cmd", type=str) 56 | def gpu_check_(num_jobs, cmd): 57 | if cmd == "run.pl" and num_jobs > 1: 58 | used_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",") 59 | assert num_jobs <= len(used_devices), f"You are requesting {num_jobs} jobs but you have {len(used_devices)} GPUs available. Exiting !" 60 | for device in used_devices: 61 | grep_res = subprocess.check_output(("nvidia-smi", "-i", f"{device}", "-q")) 62 | check = re.findall("Compute Mode\s+:\sDefault", str(grep_res)) 63 | if not len(check) == 0: 64 | logging.error( 65 | "This code may not work as expected with multiple GPUs " 66 | f"and non exclusive process compute mode." 67 | f" GPU {device} is in Default mode." 68 | f" Please switch compute mode using nvidia-smi." 69 | ) 70 | raise RuntimeError( 71 | f"GPU {device} not in exclusive process compute mode." 72 | ) 73 | 74 | 75 | @utils.command(name="split") 76 | @click.argument("num_splits", type=int) 77 | @click.argument( 78 | "manifest", type=click.Path(exists=True, dir_okay=False, allow_dash=True) 79 | ) 80 | @click.argument("output_dir", type=click.Path()) 81 | def split_(num_splits, manifest, output_dir): 82 | """ 83 | This is similar to Lhotse's split command, but we additionally try to ensure that 84 | cuts from the same recording and speaker stay in the same split as much as possible. 85 | This is done by sorting the cuts by recording ID and speaker ID, and then splitting 86 | them into chunks of approximately equal size. 87 | """ 88 | from lhotse import CutSet 89 | from lhotse.serialization import load_manifest_lazy_or_eager 90 | 91 | output_dir = Path(output_dir) 92 | manifest = Path(manifest) 93 | suffix = "".join(manifest.suffixes) 94 | cuts = load_manifest_lazy_or_eager(manifest) 95 | 96 | # sort cuts by recording ID and speaker ID 97 | cuts = CutSet.from_cuts( 98 | sorted(cuts, key=lambda c: (c.recording_id, c.supervisions[0].speaker)) 99 | ) 100 | parts = cuts.split(num_splits=num_splits, shuffle=False) 101 | output_dir.mkdir(parents=True, exist_ok=True) 102 | for idx, part in enumerate(parts): 103 | part.to_file((output_dir / manifest.stem).with_suffix(f".{str(idx+1)}{suffix}")) 104 | -------------------------------------------------------------------------------- /gss/cacgmm/cacg.py: -------------------------------------------------------------------------------- 1 | # The functions here are modified from: 2 | # https://github.com/fgnt/pb_bss/blob/master/pb_bss/distribution/complex_angular_central_gaussian.py 3 | 4 | from dataclasses import dataclass 5 | 6 | import cupy as cp 7 | 8 | from gss.cacgmm.utils import is_broadcast_compatible, normalize_observation 9 | 10 | 11 | @dataclass 12 | class ComplexAngularCentralGaussian: 13 | """ 14 | Note: 15 | Instead of the covariance the eigenvectors and eigenvalues are saved. 16 | These saves some computations, because to have a more stable covariance, 17 | the eigenvalues are floored. 
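    The stored factors correspond to covariance = V @ diag(w) @ V^H, with V the
    eigenvectors and w the (floored) eigenvalues, so e.g. the log-determinant is
    simply sum(log(w)). A minimal usage sketch (``Cov`` is assumed to be a
    precomputed Hermitian (..., D, D) CuPy array, ``y`` an observation array of
    shape (..., T, D))::

        cacg = ComplexAngularCentralGaussian.from_covariance(
            Cov, eigenvalue_floor=1e-10
        )
        log_pdf = cacg.log_pdf(y)  # shape (..., T)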
18 | """ 19 | 20 | covariance_eigenvectors: cp.array = None # (..., D, D) 21 | covariance_eigenvalues: cp.array = None # (..., D) 22 | 23 | @classmethod 24 | def from_covariance( 25 | cls, 26 | covariance, 27 | eigenvalue_floor=0.0, 28 | covariance_norm="eigenvalue", 29 | ): 30 | if covariance_norm == "trace": 31 | cov_trace = cp.einsum("...dd", covariance)[..., None, None] 32 | covariance /= cp.maximum(cov_trace, cp.finfo(cov_trace.dtype).tiny) 33 | else: 34 | assert covariance_norm in ["eigenvalue", False] 35 | 36 | try: 37 | eigenvals, eigenvecs = cp.linalg.eigh(covariance) 38 | except cp.linalg.LinAlgError: 39 | # ToDo: figure out when this happen and why eig may work. 40 | # It is likely that eig is more stable than eigh. 41 | try: 42 | eigenvals, eigenvecs = cp.linalg.eig(covariance) 43 | except cp.linalg.LinAlgError: 44 | if eigenvalue_floor == 0: 45 | raise RuntimeError( 46 | "When you set the eigenvalue_floor to zero it can " 47 | "happen that the eigenvalues get zero and the " 48 | "reciprocal eigenvalue that is used in " 49 | f"{cls.__name__}._log_pdf gets infinity." 50 | ) 51 | else: 52 | raise 53 | eigenvals = eigenvals.real 54 | if covariance_norm == "eigenvalue": 55 | # The scale of the eigenvals does not matter. 56 | eigenvals = eigenvals / cp.maximum( 57 | cp.amax(eigenvals, axis=-1, keepdims=True), 58 | cp.finfo(eigenvals.dtype).tiny, 59 | ) 60 | eigenvals = cp.maximum( 61 | eigenvals, 62 | eigenvalue_floor, 63 | ) 64 | else: 65 | eigenvals = cp.maximum( 66 | eigenvals, 67 | cp.amax(eigenvals, axis=-1, keepdims=True) * eigenvalue_floor, 68 | ) 69 | assert cp.isfinite(eigenvals).all(), eigenvals 70 | 71 | return cls( 72 | covariance_eigenvalues=eigenvals, 73 | covariance_eigenvectors=eigenvecs, 74 | ) 75 | 76 | @property 77 | def log_determinant(self): 78 | return cp.sum(cp.log(self.covariance_eigenvalues), axis=-1) 79 | 80 | def log_pdf(self, y): 81 | """ 82 | Args: 83 | y: Shape (..., T, D) 84 | Returns: 85 | """ 86 | y = normalize_observation(y) # swap D and T dim 87 | log_pdf, _ = self._log_pdf(y) 88 | return log_pdf 89 | 90 | def _log_pdf(self, y): 91 | """Gets used by. e.g. the cACGMM. 92 | TODO: quadratic_form might be useful by itself 93 | Note: y shape is (..., D, T) and not (..., T, D) like in log_pdf 94 | Args: 95 | y: Normalized observations with shape (..., D, T). 96 | Returns: Affiliations with shape (..., K, T) and quadratic format 97 | with the same shape. 
98 | """ 99 | *independent, D, T = y.shape 100 | 101 | assert is_broadcast_compatible( 102 | [*independent, D, D], self.covariance_eigenvectors.shape 103 | ), (y.shape, self.covariance_eigenvectors.shape) 104 | 105 | einsum_path = ["einsum_path", (1, 2), (1, 3), (0, 2), (0, 1)] 106 | quadratic_form = cp.maximum( 107 | cp.abs( 108 | cp.einsum( 109 | "...dt,...de,...e,...ge,...gt->...t", 110 | y.conj(), 111 | self.covariance_eigenvectors, 112 | 1 / self.covariance_eigenvalues, 113 | self.covariance_eigenvectors.conj(), 114 | y, 115 | optimize=einsum_path, 116 | ) 117 | ), 118 | cp.finfo(y.dtype).tiny, 119 | ) 120 | log_pdf = -D * cp.log(quadratic_form) 121 | log_pdf -= self.log_determinant[..., None] 122 | 123 | return log_pdf, quadratic_form 124 | -------------------------------------------------------------------------------- /gss/cacgmm/cacg_trainer.py: -------------------------------------------------------------------------------- 1 | # The functions here are modified from: 2 | # https://github.com/fgnt/pb_bss/blob/master/pb_bss/distribution/complex_angular_central_gaussian.py 3 | 4 | import cupy as cp 5 | 6 | from gss.cacgmm.cacg import ComplexAngularCentralGaussian 7 | from gss.cacgmm.utils import ( 8 | force_hermitian, 9 | is_broadcast_compatible, 10 | normalize_observation, 11 | ) 12 | 13 | 14 | class ComplexAngularCentralGaussianTrainer: 15 | def fit( 16 | self, 17 | y, 18 | saliency=None, 19 | hermitize=True, 20 | covariance_norm="eigenvalue", 21 | eigenvalue_floor=1e-10, 22 | iterations=10, 23 | ): 24 | """ 25 | Args: 26 | y: Should be normalized to unit norm. We normalize it anyway again. 27 | Shape (..., D, T), e.g. (1, D, T) for mixture models 28 | saliency: Shape (..., T), e.g. (K, T) for mixture models 29 | hermitize: 30 | eigenvalue_floor: 31 | iterations: 32 | Returns: 33 | """ 34 | *independent, T, D = y.shape 35 | assert cp.iscomplexobj(y), y.dtype 36 | assert y.shape[-1] > 1 37 | y = normalize_observation(y) # swap D and T dim 38 | 39 | if saliency is None: 40 | quadratic_form = cp.ones(*independent, T) 41 | else: 42 | raise NotImplementedError 43 | 44 | assert iterations > 0, iterations 45 | for _ in range(iterations): 46 | model = self._fit( 47 | y=y, 48 | saliency=saliency, 49 | quadratic_form=quadratic_form, 50 | hermitize=hermitize, 51 | covariance_norm=covariance_norm, 52 | eigenvalue_floor=eigenvalue_floor, 53 | ) 54 | _, quadratic_form = model._log_pdf(y) 55 | 56 | return model 57 | 58 | def _fit( 59 | self, 60 | y, 61 | saliency, 62 | quadratic_form, 63 | hermitize=True, 64 | covariance_norm="eigenvalue", 65 | eigenvalue_floor=1e-10, 66 | ) -> ComplexAngularCentralGaussian: 67 | """Single step of the fit function. In general, needs iterations. 68 | Note: y shape is (..., D, T) and not (..., T, D) like in fit 69 | Args: 70 | y: Assumed to have unit length. 71 | Shape (..., D, T), e.g. (1, D, T) for mixture models 72 | saliency: Shape (..., T), e.g. (K, T) for mixture models 73 | quadratic_form: (..., T), e.g. 
(K, T) for mixture models 74 | hermitize: 75 | eigenvalue_floor: 76 | """ 77 | assert cp.iscomplexobj(y), y.dtype 78 | 79 | assert is_broadcast_compatible(y.shape[:-2], quadratic_form.shape[:-1]), ( 80 | y.shape, 81 | quadratic_form.shape, 82 | ) 83 | 84 | D = y.shape[-2] 85 | *independent, T = quadratic_form.shape 86 | 87 | if saliency is None: 88 | saliency = 1 89 | denominator = cp.array(T, dtype=cp.float64) 90 | else: 91 | assert y.ndim == saliency.ndim + 1, (y.shape, saliency.ndim) 92 | denominator = cp.einsum("...n->...", saliency)[..., None, None] 93 | 94 | # Set 0 values in denominator to small epsilon to avoid division by 0 95 | cp.clip( 96 | denominator, 97 | a_min=cp.finfo(denominator.dtype).tiny, 98 | a_max=None, 99 | out=denominator, 100 | ) 101 | 102 | # When the covariance matrix is zero, quadratic_form would also zero. 103 | # quadratic_form have to be positive 104 | cp.clip( 105 | quadratic_form, 106 | # Use 2 * tiny, because tiny is to small 107 | a_min=10 * cp.finfo(quadratic_form.dtype).tiny, 108 | a_max=None, 109 | out=quadratic_form, 110 | ) 111 | 112 | einsum_path = ["einsum_path", (0, 2), (0, 1)] 113 | covariance = D * cp.einsum( 114 | "...dn,...Dn,...n->...dD", 115 | y, 116 | y.conj(), 117 | (saliency / quadratic_form), 118 | optimize=einsum_path, 119 | ) 120 | assert cp.isfinite(quadratic_form).all() 121 | covariance /= denominator 122 | assert covariance.shape == (*independent, D, D), ( 123 | covariance.shape, 124 | (*independent, D, D), 125 | ) 126 | 127 | assert cp.isfinite(covariance).all() 128 | 129 | if hermitize: 130 | covariance = force_hermitian(covariance) 131 | 132 | return ComplexAngularCentralGaussian.from_covariance( 133 | covariance, 134 | eigenvalue_floor=eigenvalue_floor, 135 | covariance_norm=covariance_norm, 136 | ) 137 | -------------------------------------------------------------------------------- /gss/utils/numpy_utils.py: -------------------------------------------------------------------------------- 1 | import cupy as cp 2 | import numpy as np 3 | 4 | 5 | def segment_axis( 6 | x, 7 | length: int, 8 | shift: int, 9 | axis: int = -1, 10 | end="pad", 11 | pad_mode="constant", 12 | pad_value=0, 13 | ): 14 | """!!! WIP !!! 15 | 16 | ToDo: Discuss: Outsource conv_pad? 17 | 18 | Generate a new array that chops the given array along the given axis 19 | into overlapping frames. 20 | 21 | Note: if end='pad' the return is maybe a copy 22 | 23 | :param x: The array to segment 24 | :param length: The length of each frame 25 | :param shift: The number of array elements by which the frames should shift 26 | Negative values are also allowed. 
27 | :param axis: The axis to operate on 28 | :param end: 29 | 'pad' -> pad, 30 | pad the last block with zeros if necessary 31 | None -> assert, 32 | assume the length match, ensures a no copy 33 | 'cut' -> cut, 34 | remove the last block if there are not enough values 35 | 'conv_pad' 36 | special padding for convolution, assumes shift == 1, see example 37 | below 38 | 39 | :param pad_mode: see numpy.pad 40 | :param pad_value: The value to pad 41 | :return: 42 | 43 | """ 44 | xp = cp.get_array_module(x) 45 | 46 | axis = axis % x.ndim 47 | 48 | # Implement negative shift with a positive shift and a flip 49 | # stride_tricks does not work correct with negative stride 50 | if shift > 0: 51 | do_flip = False 52 | elif shift < 0: 53 | do_flip = True 54 | shift = abs(shift) 55 | else: 56 | raise ValueError(shift) 57 | 58 | if pad_mode == "constant": 59 | pad_kwargs = {"constant_values": pad_value} 60 | else: 61 | pad_kwargs = {} 62 | 63 | # Pad 64 | if end == "pad": 65 | if x.shape[axis] < length: 66 | npad = np.zeros([x.ndim, 2], dtype=xp.int) 67 | npad[axis, 1] = length - x.shape[axis] 68 | x = xp.pad(x, pad_width=npad, mode=pad_mode, **pad_kwargs) 69 | elif shift != 1 and (x.shape[axis] + shift - length) % shift != 0: 70 | npad = np.zeros([x.ndim, 2], dtype=int) 71 | npad[axis, 1] = shift - ((x.shape[axis] + shift - length) % shift) 72 | x = xp.pad(x, pad_width=npad, mode=pad_mode, **pad_kwargs) 73 | 74 | elif end == "conv_pad": 75 | assert shift == 1, shift 76 | npad = np.zeros([x.ndim, 2], dtype=int) 77 | npad[axis, :] = length - shift 78 | x = xp.pad(x, pad_width=npad, mode=pad_mode, **pad_kwargs) 79 | elif end is None: 80 | assert ( 81 | x.shape[axis] + shift - length 82 | ) % shift == 0, "{} = x.shape[axis]({}) + shift({}) - length({})) % shift({})" "".format( 83 | (x.shape[axis] + shift - length) % shift, 84 | x.shape[axis], 85 | shift, 86 | length, 87 | shift, 88 | ) 89 | elif end == "cut": 90 | pass 91 | else: 92 | raise ValueError(end) 93 | 94 | # Calculate desired shape and strides 95 | shape = list(x.shape) 96 | # assert shape[axis] >= length, shape 97 | del shape[axis] 98 | shape.insert(axis, (x.shape[axis] + shift - length) // shift) 99 | shape.insert(axis + 1, length) 100 | 101 | strides = list(x.strides) 102 | strides.insert(axis, shift * strides[axis]) 103 | 104 | try: 105 | x = xp.lib.stride_tricks.as_strided(x, strides=strides, shape=shape) 106 | 107 | except Exception: 108 | print("strides:", x.strides, " -> ", strides) 109 | print("shape:", x.shape, " -> ", shape) 110 | print("flags:", x.flags) 111 | print("Parameters:") 112 | print( 113 | "shift:", 114 | shift, 115 | "Note: negative shift is implemented with a " "following flip", 116 | ) 117 | print("length:", length, "<- Has to be positive.") 118 | raise 119 | if do_flip: 120 | return xp.flip(x, axis=axis) 121 | else: 122 | return x 123 | 124 | 125 | # http://stackoverflow.com/a/3153267 126 | def roll_zeropad(a, shift, axis=None): 127 | """ 128 | Roll array elements along a given axis. 129 | 130 | Elements off the end of the array are treated as zeros. 131 | 132 | Parameters 133 | ---------- 134 | a : array_like 135 | Input array. 136 | shift : int 137 | The number of places by which elements are shifted. 138 | axis : int, optional 139 | The axis along which elements are shifted. By default, the array 140 | is flattened before shifting, after which the original 141 | shape is restored. 142 | 143 | Returns 144 | ------- 145 | res : ndarray 146 | Output array, with the same shape as `a`. 
147 | 148 | See Also 149 | -------- 150 | roll : Elements that roll off one end come back on the other. 151 | rollaxis : Roll the specified axis backwards, until it lies in a 152 | given position. 153 | 154 | """ 155 | if a.__class__.__module__ == "cupy.core.core": 156 | import cupy 157 | 158 | xp = cupy 159 | else: 160 | xp = np 161 | 162 | if shift == 0: 163 | return a 164 | if axis is None: 165 | n = a.size 166 | reshape = True 167 | else: 168 | n = a.shape[axis] 169 | reshape = False 170 | if xp.abs(shift) > n: 171 | res = xp.zeros_like(a) 172 | elif shift < 0: 173 | shift += n 174 | zeros = xp.zeros_like(a.take(xp.arange(n - shift), axis)) 175 | res = xp.concatenate((a.take(xp.arange(n - shift, n), axis), zeros), axis) 176 | else: 177 | zeros = xp.zeros_like(a.take(xp.arange(n - shift, n), axis)) 178 | res = xp.concatenate((zeros, a.take(xp.arange(n - shift), axis)), axis) 179 | if reshape: 180 | return res.reshape(a.shape) 181 | else: 182 | return res 183 | -------------------------------------------------------------------------------- /gss/beamformer/utils.py: -------------------------------------------------------------------------------- 1 | # The functions in this file are modified from: 2 | # https://github.com/fgnt/pb_chime5/blob/master/pb_chime5/utils/numpy_utils.py 3 | 4 | import re 5 | 6 | import cupy as cp 7 | from numpy.core.einsumfunc import _parse_einsum_input 8 | 9 | 10 | def _normalize(op): 11 | op = op.replace(",", "") 12 | op = op.replace(" ", "") 13 | op = " ".join(c for c in op) 14 | op = op.replace(" * ", "*") 15 | op = op.replace("- >", "->") 16 | op = op.replace(". . .", "...") 17 | return op 18 | 19 | 20 | def _shrinking_reshape(array, source, target): 21 | source, target = source.split(), target.replace(" * ", "*").split() 22 | 23 | if "..." in source: 24 | assert "..." in target, (source, target) 25 | independent_dims = array.ndim - len(source) + 1 26 | import string 27 | 28 | ascii_letters = [ 29 | s for s in string.ascii_letters if s not in source and s not in target 30 | ] 31 | index = source.index("...") 32 | source[index : index + 1] = ascii_letters[:independent_dims] 33 | index = target.index("...") 34 | target[index : index + 1] = ascii_letters[:independent_dims] 35 | 36 | input_shape = {key: array.shape[index] for index, key in enumerate(source)} 37 | 38 | output_shape = [] 39 | for t in target: 40 | product = 1 41 | if not t == "1": 42 | t = t.split("*") 43 | for t_ in t: 44 | product *= input_shape[t_] 45 | output_shape.append(product) 46 | 47 | return array.reshape(output_shape) 48 | 49 | 50 | def _expanding_reshape(array, source, target, **shape_hints): 51 | try: # Check number of inputs for unflatten operations 52 | assert len(re.sub(r".\*", "", source.replace(" ", ""))) == array.ndim, ( 53 | array.shape, 54 | source, 55 | target, 56 | ) 57 | except AssertionError: # Check number of inputs for ellipses operations 58 | assert ( 59 | len(re.sub(r"(\.\.\.)|(.\*)", "", source.replace(" ", ""))) <= array.ndim 60 | ), (array.shape, source, target) 61 | 62 | def _get_source_grouping(source): 63 | """ 64 | Gets axis as alphanumeric. 65 | """ 66 | 67 | source = " ".join(source) 68 | source = source.replace(" * ", "*") 69 | groups = source.split() 70 | groups = [group.split("*") for group in groups] 71 | return groups 72 | 73 | if "*" not in source: 74 | return array 75 | 76 | source, target = source.split(), target.replace(" * ", "*").split() 77 | 78 | if "..." in source: 79 | assert "..." 
in target, (source, target) 80 | independent_dims = array.ndim - len(source) + 1 81 | import string 82 | 83 | ascii_letters = [ 84 | s for s in string.ascii_letters if s not in source and s not in target 85 | ] 86 | index = source.index("...") 87 | source[index : index + 1] = ascii_letters[:independent_dims] 88 | index = target.index("...") 89 | target[index : index + 1] = ascii_letters[:independent_dims] 90 | 91 | target_shape = [] 92 | 93 | for axis, group in enumerate(_get_source_grouping(source)): 94 | if len(group) == 1: 95 | target_shape.append(array.shape[axis : axis + 1]) 96 | else: 97 | shape_wildcard_remaining = True 98 | for member in group: 99 | if member in shape_hints: 100 | target_shape.append([shape_hints[member]]) 101 | else: 102 | if shape_wildcard_remaining: 103 | shape_wildcard_remaining = False 104 | target_shape.append([-1]) 105 | else: 106 | raise ValueError("Not enough shape hints provided.") 107 | 108 | target_shape = cp.concatenate(target_shape, 0) 109 | array = array.reshape(target_shape) 110 | return array 111 | 112 | 113 | def morph(operation, array, reduce=None, **shape_hints): 114 | """This is an experimental version of a generalized reshape. 115 | See test cases for examples. 116 | """ 117 | operation = _normalize(operation) 118 | source, target = operation.split("->") 119 | 120 | # Expanding reshape 121 | array = _expanding_reshape(array, source, target, **shape_hints) 122 | 123 | # Initial squeeze 124 | squeeze_operation = operation.split("->")[0].split() 125 | for axis, op in reversed(list(enumerate(squeeze_operation))): 126 | if op == "1": 127 | array = cp.squeeze(array, axis=axis) 128 | 129 | # Transpose 130 | transposition_operation = operation.replace("1", " ").replace("*", " ") 131 | try: 132 | in_shape, out_shape, (array,) = _parse_einsum_input( 133 | [transposition_operation.replace(" ", ""), array.get()] 134 | ) 135 | 136 | if len(set(in_shape) - set(out_shape)) > 0: 137 | assert reduce is not None, ( 138 | "Missing reduce function", 139 | reduce, 140 | transposition_operation, 141 | ) 142 | 143 | reduce_axis = tuple( 144 | [i for i, s in enumerate(in_shape) if s not in out_shape] 145 | ) 146 | array = reduce(array, axis=reduce_axis) 147 | in_shape = "".join([s for s in in_shape if s in out_shape]) 148 | 149 | array = cp.einsum(f"{in_shape}->{out_shape}", array) 150 | except ValueError as e: 151 | msg = ( 152 | f"op: {transposition_operation} ({in_shape}->{out_shape}), " 153 | f"shape: {cp.shape(array)}" 154 | ) 155 | 156 | if len(e.args) == 1: 157 | e.args = (e.args[0] + "\n\n" + msg,) 158 | else: 159 | print(msg) 160 | raise 161 | 162 | # Final reshape 163 | source = transposition_operation.split("->")[-1] 164 | target = operation.split("->")[-1] 165 | 166 | return _shrinking_reshape(array, source, target) 167 | -------------------------------------------------------------------------------- /gss/wpe/wpe.py: -------------------------------------------------------------------------------- 1 | # The functions in this module are modified from: 2 | # https://github.com/fgnt/nara_wpe/blob/master/nara_wpe/wpe.py 3 | 4 | import functools 5 | import operator 6 | 7 | import cupy as cp 8 | import numpy as np 9 | 10 | from gss.utils.numpy_utils import segment_axis 11 | 12 | 13 | def get_working_shape(shape): 14 | "Flattens all but the last two dimension." 
15 | product = functools.reduce(operator.mul, [1] + list(shape[:-2])) 16 | return [product] + list(shape[-2:]) 17 | 18 | 19 | def _stable_solve(A, B): 20 | assert A.shape[:-2] == B.shape[:-2], (A.shape, B.shape) 21 | assert A.shape[-1] == B.shape[-2], (A.shape, B.shape) 22 | try: 23 | return cp.linalg.solve(A, B) 24 | except: 25 | shape_A, shape_B = A.shape, B.shape 26 | assert shape_A[:-2] == shape_A[:-2] 27 | working_shape_A = get_working_shape(shape_A) 28 | working_shape_B = get_working_shape(shape_B) 29 | A = A.reshape(working_shape_A) 30 | B = B.reshape(working_shape_B) 31 | 32 | C = cp.zeros_like(B) 33 | for i in range(working_shape_A[0]): 34 | # lstsq is much slower, use it only when necessary 35 | try: 36 | C[i] = cp.linalg.solve(A[i], B[i]) 37 | except cp.linalg.linalg.LinAlgError: 38 | C[i] = cp.linalg.lstsq(A[i], B[i])[0] 39 | return C.reshape(*shape_B) 40 | 41 | 42 | def build_y_tilde(Y, taps, delay): 43 | S = Y.shape[:-2] 44 | D = Y.shape[-2] 45 | T = Y.shape[-1] 46 | 47 | def pad(x, axis=-1, pad_width=taps + delay - 1): 48 | npad = np.zeros([x.ndim, 2], dtype=int) 49 | npad[axis, 0] = pad_width 50 | x = cp.pad(x, pad_width=npad, mode="constant", constant_values=0) 51 | return x 52 | 53 | Y_ = pad(Y) 54 | Y_ = cp.moveaxis(Y_, -1, -2) 55 | Y_ = cp.flip(Y_, axis=-1) 56 | Y_ = cp.ascontiguousarray(Y_) 57 | Y_ = cp.flip(Y_, axis=-1) 58 | Y_ = segment_axis(Y_, taps, 1, axis=-2) 59 | Y_ = cp.flip(Y_, axis=-2) 60 | if delay > 0: 61 | Y_ = Y_[..., :-delay, :, :] 62 | Y_ = cp.reshape(Y_, list(S) + [T, taps * D]) 63 | Y_ = cp.moveaxis(Y_, -2, -1) 64 | 65 | return Y_ 66 | 67 | 68 | def hermite(x): 69 | return x.swapaxes(-2, -1).conj() 70 | 71 | 72 | def get_power_inverse(signal, psd_context=0): 73 | power = cp.mean(abs_square(signal), axis=-2) 74 | 75 | if np.isposinf(psd_context): 76 | power = cp.broadcast_to(cp.mean(power, axis=-1, keepdims=True), power.shape) 77 | elif psd_context > 0: 78 | assert int(psd_context) == psd_context, psd_context 79 | psd_context = int(psd_context) 80 | power = window_mean(power, (psd_context, psd_context)) 81 | elif psd_context == 0: 82 | pass 83 | else: 84 | raise ValueError(psd_context) 85 | return _stable_positive_inverse(power) 86 | 87 | 88 | def abs_square(x): 89 | if cp.iscomplexobj(x): 90 | return x.real**2 + x.imag**2 91 | else: 92 | return x**2 93 | 94 | 95 | def window_mean(x, lr_context, axis=-1): 96 | if isinstance(lr_context, int): 97 | lr_context = [lr_context + 1, lr_context] 98 | else: 99 | assert len(lr_context) == 2, lr_context 100 | tmp_l_context, tmp_r_context = lr_context 101 | lr_context = tmp_l_context + 1, tmp_r_context 102 | 103 | x = cp.asarray(x) 104 | 105 | window_length = sum(lr_context) 106 | if window_length == 0: 107 | return x 108 | 109 | pad_width = np.zeros((x.ndim, 2), dtype=np.int64) 110 | pad_width[axis] = lr_context 111 | 112 | first_slice = [slice(None)] * x.ndim 113 | first_slice[axis] = slice(sum(lr_context), None) 114 | second_slice = [slice(None)] * x.ndim 115 | second_slice[axis] = slice(None, -sum(lr_context)) 116 | 117 | def foo(x): 118 | cumsum = cp.cumsum(cp.pad(x, pad_width, mode="constant"), axis=axis) 119 | return cumsum[first_slice] - cumsum[second_slice] 120 | 121 | ones_shape = [1] * x.ndim 122 | ones_shape[axis] = x.shape[axis] 123 | 124 | return foo(x) / foo(cp.ones(ones_shape, cp.int64)) 125 | 126 | 127 | def _stable_positive_inverse(power): 128 | eps = 1e-10 * cp.max(power) 129 | if eps == 0: 130 | # Special case when signal is zero. 131 | # Does not happen on real data. 
132 | # This only happens in artificial cases, e.g. redacted signal parts, 133 | # where the signal is set to be zero from a human. 134 | # 135 | # The scale of the power does not matter, so take 1. 136 | inverse_power = cp.ones_like(power) 137 | else: 138 | cp.clip(power, a_min=eps, a_max=None, out=power) 139 | inverse_power = 1 / power 140 | return inverse_power 141 | 142 | 143 | def wpe(Y, taps=10, delay=3, iterations=3, psd_context=0, statistics_mode="full"): 144 | """ 145 | Batched WPE implementation (same as wpe_v6 in nara_wpe) 146 | 147 | Applicable in for-loops. 148 | 149 | Args: 150 | Y: Complex valued STFT signal with shape (F, D, T). 151 | taps: Filter order 152 | delay: Delay as a guard interval, such that X does not become zero. 153 | iterations: 154 | psd_context: Defines the number of elements in the time window 155 | to improve the power estimation. Total number of elements will 156 | be (psd_context + 1 + psd_context). 157 | statistics_mode: Either 'full' or 'valid'. 158 | 'full': Pad the observation with zeros on the left for the 159 | estimation of the correlation matrix and vector. 160 | 'valid': Only calculate correlation matrix and vector on valid 161 | slices of the observation. 162 | 163 | Returns: 164 | Estimated signal with the same shape as Y 165 | 166 | """ 167 | 168 | if statistics_mode == "full": 169 | s = Ellipsis 170 | elif statistics_mode == "valid": 171 | s = (Ellipsis, slice(delay + taps - 1, None)) 172 | else: 173 | raise ValueError(statistics_mode) 174 | 175 | X = cp.copy(Y) 176 | Y_tilde = build_y_tilde(Y, taps, delay) 177 | for iteration in range(iterations): 178 | inverse_power = get_power_inverse(X, psd_context=psd_context) 179 | Y_tilde_inverse_power = Y_tilde * inverse_power[..., None, :] 180 | R = cp.matmul(Y_tilde_inverse_power[s], hermite(Y_tilde[s])) 181 | P = cp.matmul(Y_tilde_inverse_power[s], hermite(Y[s])) 182 | G = _stable_solve(R, P) 183 | X = Y - cp.matmul(hermite(G), Y_tilde) 184 | 185 | return X 186 | -------------------------------------------------------------------------------- /gss/core/stft_module.py: -------------------------------------------------------------------------------- 1 | # The functions here are modified from: 2 | # https://github.com/fgnt/paderbox/blob/master/paderbox/transform/module_stft.py 3 | import string 4 | import typing 5 | from math import ceil 6 | 7 | import cupy as cp 8 | import cupyx as cpx 9 | import numpy as np 10 | from cupy.fft import irfft, rfft 11 | 12 | from gss.utils.numpy_utils import roll_zeropad, segment_axis 13 | 14 | 15 | def stft( 16 | time_signal, 17 | size: int = 1024, 18 | shift: int = 256, 19 | *, 20 | axis=-1, 21 | fading=True, 22 | ) -> cp.array: 23 | """ 24 | Calculates the short time Fourier transformation of a multi channel multi 25 | speaker time signal. It is able to add additional zeros for fade-in and 26 | fade out and should yield an STFT signal which allows perfect 27 | reconstruction. 28 | 29 | :param time_signal: Multi channel time signal with dimensions 30 | AA x ... x AZ x T x BA x ... x BZ. 31 | :param size: Scalar FFT-size. 32 | :param shift: Scalar FFT-shift, the step between successive frames in 33 | samples. Typically shift is a fraction of size. 34 | :param axis: Scalar axis of time. 35 | Default: None means the biggest dimension. 36 | :param fading: Pads the signal with zeros for better reconstruction. 37 | :return: Single channel complex STFT signal with dimensions 38 | AA x ... x AZ x T' times size/2+1 times BA x ... x BZ. 
39 | """ 40 | ndim = time_signal.ndim 41 | axis = axis % ndim 42 | 43 | window_length = size 44 | 45 | # Pad with zeros to have enough samples for the window function to fade. 46 | assert fading in [None, True, False, "full", "half"], fading 47 | if fading not in [False, None]: 48 | pad_width = np.zeros([ndim, 2], dtype=int) 49 | if fading == "half": 50 | pad_width[axis, 0] = (window_length - shift) // 2 51 | pad_width[axis, 1] = ceil((window_length - shift) / 2) 52 | else: 53 | pad_width[axis, :] = window_length - shift 54 | time_signal = cp.pad(time_signal, pad_width, mode="constant") 55 | 56 | window = cp.blackman(window_length + 1)[:-1] 57 | 58 | time_signal_seg = segment_axis( 59 | time_signal, window_length, shift=shift, axis=axis, end="pad" 60 | ) 61 | 62 | letters = string.ascii_lowercase[: time_signal_seg.ndim] 63 | mapping = letters + "," + letters[axis + 1] + "->" + letters 64 | 65 | try: 66 | return rfft( 67 | cp.einsum(mapping, time_signal_seg, window), 68 | n=size, 69 | axis=axis + 1, 70 | ) 71 | except ValueError as e: 72 | raise ValueError( 73 | f"Could not calculate the stft, something does not match.\n" 74 | f"mapping: {mapping}, " 75 | f"time_signal_seg.shape: {time_signal_seg.shape}, " 76 | f"window.shape: {window.shape}, " 77 | f"size: {size}" 78 | f"axis+1: {axis+1}" 79 | ) from e 80 | 81 | 82 | def _biorthogonal_window_brute_force(analysis_window, shift, use_amplitude=False): 83 | """ 84 | The biorthogonal window (synthesis_window) must verify the criterion: 85 | synthesis_window * analysis_window plus it's shifts must be one. 86 | 1 == sum m from -inf to inf over (synthesis_window(n - mB) * analysis_window(n - mB)) 87 | B ... shift 88 | n ... time index 89 | m ... shift index 90 | 91 | :param analysis_window: 92 | :param shift: 93 | :return: 94 | 95 | >>> analysis_window = signal.windows.blackman(4+1)[:-1] 96 | >>> print(analysis_window) 97 | [-1.38777878e-17 3.40000000e-01 1.00000000e+00 3.40000000e-01] 98 | >>> synthesis_window = _biorthogonal_window_brute_force(analysis_window, 1) 99 | >>> print(synthesis_window) 100 | [-1.12717575e-17 2.76153346e-01 8.12215724e-01 2.76153346e-01] 101 | >>> mult = analysis_window * synthesis_window 102 | >>> sum(mult) 103 | 1.0000000000000002 104 | """ 105 | size = len(analysis_window) 106 | 107 | influence_width = (size - 1) // shift 108 | 109 | denominator = cp.zeros_like(analysis_window) 110 | 111 | if use_amplitude: 112 | analysis_window_square = analysis_window 113 | else: 114 | analysis_window_square = analysis_window**2 115 | for i in range(-influence_width, influence_width + 1): 116 | denominator += roll_zeropad(analysis_window_square, shift * i) 117 | 118 | if use_amplitude: 119 | synthesis_window = 1 / denominator 120 | else: 121 | synthesis_window = analysis_window / denominator 122 | return synthesis_window 123 | 124 | 125 | def istft( 126 | stft_signal, 127 | size: int = 1024, 128 | shift: int = 256, 129 | *, 130 | fading: typing.Optional[typing.Union[bool, str]] = "full", 131 | ): 132 | """ 133 | Calculated the inverse short time Fourier transform to exactly reconstruct 134 | the time signal. 135 | 136 | ..note:: 137 | Be careful if you make modifications in the frequency domain (e.g. 138 | beamforming) because the synthesis window is calculated according to 139 | the unmodified! analysis window. 140 | 141 | :param stft_signal: Single channel complex STFT signal 142 | with dimensions (..., frames, size/2+1). 143 | :param size: Scalar FFT-size. 144 | :param shift: Scalar FFT-shift. 
Typically shift is a fraction of size. 145 | :param fading: Removes the additional padding, if done during STFT. 146 | 147 | :return: Single channel complex STFT signal 148 | :return: Single channel time signal. 149 | """ 150 | assert stft_signal.shape[-1] == size // 2 + 1, str(stft_signal.shape) 151 | 152 | window_length = size 153 | 154 | window = cp.blackman(window_length + 1)[:-1] 155 | window = _biorthogonal_window_brute_force(window, shift) 156 | 157 | # In the following, we use numpy.add.at since cupyx.scatter_add does not seem to be 158 | # giving the same results. We should replace this with cupy.add.at once it is 159 | # available in the stable release (see: https://github.com/cupy/cupy/pull/7077). 160 | 161 | time_signal = np.zeros( 162 | (*stft_signal.shape[:-2], stft_signal.shape[-2] * shift + window_length - shift) 163 | ) 164 | 165 | # Get the correct view to time_signal 166 | time_signal_seg = segment_axis(time_signal, window_length, shift, end=None) 167 | 168 | np.add.at( 169 | time_signal_seg, 170 | ..., 171 | (window * cp.real(irfft(stft_signal, n=size))[..., :window_length]).get(), 172 | ) 173 | # The [..., :window_length] is the inverse of the window padding in rfft. 174 | 175 | # Compensate fade-in and fade-out 176 | 177 | assert fading in [None, True, False, "full", "half"], fading 178 | if fading not in [None, False]: 179 | pad_width = window_length - shift 180 | if fading == "half": 181 | pad_width /= 2 182 | time_signal = time_signal[ 183 | ..., int(pad_width) : time_signal.shape[-1] - ceil(pad_width) 184 | ] 185 | 186 | return time_signal 187 | -------------------------------------------------------------------------------- /gss/beamformer/beamform.py: -------------------------------------------------------------------------------- 1 | # The functions here are modified from the corresponding implementations in: 2 | # https://github.com/fgnt/pb_bss/blob/master/pb_bss/extraction/beamformer.py 3 | 4 | import functools 5 | import operator 6 | 7 | import cupy as cp 8 | import numpy as np 9 | 10 | __all__ = [ 11 | "get_power_spectral_density_matrix", 12 | "get_mvdr_vector_souden", 13 | "blind_analytic_normalization", 14 | "apply_beamforming_vector", 15 | ] 16 | 17 | 18 | def get_power_spectral_density_matrix( 19 | observation, 20 | mask=None, 21 | sensor_dim=-2, 22 | source_dim=-2, 23 | time_dim=-1, 24 | normalize=True, 25 | ): 26 | # ensure negative dim indexes 27 | sensor_dim, source_dim, time_dim = ( 28 | d % observation.ndim - observation.ndim 29 | for d in (sensor_dim, source_dim, time_dim) 30 | ) 31 | 32 | # ensure observation shape (..., sensors, frames) 33 | obs_transpose = [ 34 | i for i in range(-observation.ndim, 0) if i not in [sensor_dim, time_dim] 35 | ] + [sensor_dim, time_dim] 36 | observation = observation.transpose(obs_transpose) 37 | 38 | if mask is None: 39 | psd = cp.einsum( 40 | "...dt,...et->...de", observation, observation.conj(), optimize="optimal" 41 | ) 42 | 43 | # normalize 44 | psd /= observation.shape[-1] 45 | 46 | else: 47 | # Unfortunately, this function changes `mask`. 
48 | mask = cp.copy(mask) 49 | 50 | # normalize 51 | if mask.dtype == cp.bool_: 52 | mask = cp.asfarray(mask) 53 | 54 | if normalize: 55 | mask /= cp.maximum( 56 | cp.sum(mask, axis=time_dim, keepdims=True), 57 | 1e-10, 58 | ) 59 | 60 | if mask.ndim + 1 == observation.ndim: 61 | mask = cp.expand_dims(mask, -2) 62 | psd = cp.einsum( 63 | "...dt,...et->...de", 64 | mask * observation, 65 | observation.conj(), 66 | ) 67 | else: 68 | # ensure shape (..., sources, frames) 69 | mask_transpose = [ 70 | i 71 | for i in range(-observation.ndim, 0) 72 | if i not in [source_dim, time_dim] 73 | ] + [source_dim, time_dim] 74 | mask = mask.transpose(mask_transpose) 75 | 76 | psd = cp.einsum( 77 | "...kt,...dt,...et->...kde", 78 | mask, 79 | observation, 80 | observation.conj(), 81 | optimize="optimal", 82 | ) 83 | 84 | if source_dim < -2: 85 | # Assume PSD shape (sources, ..., sensors, sensors) is desired 86 | psd = cp.rollaxis(psd, -3, source_dim % observation.ndim) 87 | 88 | return psd 89 | 90 | 91 | def blind_analytic_normalization(vector, noise_psd_matrix): 92 | einsum_path = ["einsum_path", (0, 1), (0, 1), (0, 1)] 93 | nominator = cp.einsum( 94 | "...a,...ab,...bc,...c->...", 95 | vector.conj(), 96 | noise_psd_matrix, 97 | noise_psd_matrix, 98 | vector, 99 | optimize=einsum_path, 100 | ) 101 | nominator = cp.sqrt(nominator) 102 | 103 | einsum_path = ["einsum_path", (0, 1), (0, 1)] 104 | denominator = cp.einsum( 105 | "...a,...ab,...b->...", 106 | vector.conj(), 107 | noise_psd_matrix, 108 | vector, 109 | optimize=einsum_path, 110 | ) 111 | denominator = cp.sqrt(denominator * denominator.conj()) 112 | 113 | # We do the division in numpy since the `where` argument is not available in cupy 114 | nominator = cp.asnumpy(nominator) 115 | denominator = cp.asnumpy(denominator) 116 | normalization = np.divide( # https://stackoverflow.com/a/37977222/5766934 117 | nominator, denominator, out=np.zeros_like(nominator), where=denominator != 0 118 | ) 119 | normalization = cp.asarray(normalization) 120 | 121 | return vector * cp.abs(normalization[..., cp.newaxis]) 122 | 123 | 124 | def apply_beamforming_vector(vector, mix): 125 | assert vector.shape[-1] < 30, (vector.shape, mix.shape) 126 | return cp.einsum("...a,...at->...t", vector.conj(), mix) 127 | 128 | 129 | def get_optimal_reference_channel( 130 | w_mat, 131 | target_psd_matrix, 132 | noise_psd_matrix, 133 | eps=None, 134 | ): 135 | if w_mat.ndim != 3: 136 | raise ValueError( 137 | "Estimating the ref_channel expects currently that the input " 138 | "has 3 ndims (frequency x sensors x sensors). " 139 | "Considering an independent dim in the SNR estimate is not " 140 | "unique." 
141 | ) 142 | if eps is None: 143 | eps = cp.finfo(w_mat.dtype).tiny 144 | einsum_path = ["einsum_path", (0, 1), (0, 1)] 145 | SNR = cp.einsum( 146 | "...FdR,...FdD,...FDR->...R", 147 | w_mat.conj(), 148 | target_psd_matrix, 149 | w_mat, 150 | optimize=einsum_path, 151 | ) / cp.maximum( 152 | cp.einsum( 153 | "...FdR,...FdD,...FDR->...R", 154 | w_mat.conj(), 155 | noise_psd_matrix, 156 | w_mat, 157 | optimize=einsum_path, 158 | ), 159 | eps, 160 | ) 161 | # Raises an exception when np.inf and/or np.NaN was in target_psd_matrix 162 | # or noise_psd_matrix 163 | assert cp.all(cp.isfinite(SNR)), SNR 164 | return cp.argmax(SNR.real) 165 | 166 | 167 | def stable_solve(A, B): 168 | assert A.shape[:-2] == B.shape[:-2], (A.shape, B.shape) 169 | assert A.shape[-1] == B.shape[-2], (A.shape, B.shape) 170 | try: 171 | return cp.linalg.solve(A, B) 172 | except: # noqa 173 | shape_A, shape_B = A.shape, B.shape 174 | assert shape_A[:-2] == shape_A[:-2] 175 | working_shape_A = [ 176 | functools.reduce(operator.mul, [1, *shape_A[:-2]]), 177 | *shape_A[-2:], 178 | ] 179 | working_shape_B = [ 180 | functools.reduce(operator.mul, [1, *shape_B[:-2]]), 181 | *shape_B[-2:], 182 | ] 183 | A = A.reshape(working_shape_A) 184 | B = B.reshape(working_shape_B) 185 | 186 | C = cp.zeros_like(B) 187 | for i in range(working_shape_A[0]): 188 | # lstsq is much slower, use it only when necessary 189 | try: 190 | C[i] = cp.linalg.solve(A[i], B[i]) 191 | except cp.linalg.LinAlgError: 192 | C[i], *_ = cp.linalg.lstsq(A[i], B[i]) 193 | return C.reshape(*shape_B) 194 | 195 | 196 | def get_mvdr_vector_souden( 197 | target_psd_matrix, 198 | noise_psd_matrix, 199 | ref_channel=None, 200 | eps=None, 201 | ): 202 | assert noise_psd_matrix is not None 203 | 204 | phi = stable_solve(noise_psd_matrix, target_psd_matrix) 205 | lambda_ = cp.trace(phi, axis1=-1, axis2=-2)[..., None, None] 206 | if eps is None: 207 | eps = cp.finfo(lambda_.dtype).tiny 208 | mat = phi / cp.maximum(lambda_.real, eps) 209 | 210 | if ref_channel is None: 211 | ref_channel = get_optimal_reference_channel( 212 | mat, target_psd_matrix, noise_psd_matrix, eps=eps 213 | ) 214 | 215 | beamformer = mat[..., ref_channel] 216 | return beamformer 217 | -------------------------------------------------------------------------------- /gss/cacgmm/cacgmm_trainer.py: -------------------------------------------------------------------------------- 1 | # The functions here are modified from: 2 | # https://github.com/fgnt/pb_bss/blob/master/pb_bss/distribution/cacgmm.py 3 | 4 | from operator import xor 5 | 6 | import cupy as cp 7 | 8 | from gss.cacgmm.cacg_trainer import ComplexAngularCentralGaussianTrainer 9 | from gss.cacgmm.cacgmm import CACGMM 10 | from gss.cacgmm.utils import estimate_mixture_weight, normalize_observation 11 | 12 | 13 | class CACGMMTrainer: 14 | def fit( 15 | self, 16 | y, 17 | initialization=None, 18 | num_classes=None, 19 | iterations=100, 20 | saliency=None, 21 | *, 22 | source_activity_mask=None, 23 | weight_constant_axis=(-1,), 24 | hermitize=True, 25 | covariance_norm="eigenvalue", 26 | affiliation_eps=1e-10, 27 | eigenvalue_floor=1e-10, 28 | ): 29 | """ 30 | 31 | Args: 32 | y: Shape (frequency, time, channel) or (F, T, D) 33 | initialization: 34 | Affiliations between 0 and 1. Shape (F, K, T) 35 | or CACGMM instance 36 | num_classes: Scalar >0 37 | iterations: Scalar >0 38 | saliency: 39 | Importance weighting for each observation, shape (..., T) 40 | Should be pre-calculated externally, not just a string. 
41 | source_activity_mask: Boolean mask that says for each time point 42 | for each source if it is active or not. 43 | Shape (F, K, T) 44 | weight_constant_axis: The axis that is used to calculate the mean 45 | over the affiliations. The affiliations have the 46 | shape (F, K, T), so the default value means averaging over 47 | the sample dimension. Note that averaging over an independent 48 | axis is supported. 49 | hermitize: 50 | covariance_norm: 'eigenvalue', 'trace' or False 51 | affiliation_eps: 52 | eigenvalue_floor: Relative flooring of the covariance eigenvalues 53 | 54 | Returns: 55 | 56 | """ 57 | assert xor(initialization is None, num_classes is None), ( 58 | "Incompatible input combination. " 59 | "Exactly one of the two inputs has to be None: " 60 | f"{initialization is None} xor {num_classes is None}" 61 | ) 62 | 63 | assert cp.iscomplexobj(y), y.dtype 64 | assert y.shape[-1] > 1, y.shape 65 | y = normalize_observation(y) # swap D and T dim, now y is F, D, T 66 | 67 | assert iterations > 0, iterations 68 | 69 | model = None 70 | 71 | *independent, D, num_observations = y.shape 72 | if initialization is None: 73 | assert num_classes is not None, num_classes 74 | affiliation_shape = (*independent, num_classes, num_observations) 75 | affiliation = cp.random.uniform(size=affiliation_shape) 76 | affiliation /= cp.einsum("fkn->fn", affiliation)[..., None, :] 77 | quadratic_form = cp.ones(affiliation_shape, dtype=y.real.dtype) 78 | elif isinstance(initialization, cp.ndarray): 79 | num_classes = initialization.shape[-2] 80 | assert num_classes > 1, num_classes 81 | affiliation_shape = (*independent, num_classes, num_observations) 82 | 83 | # Force same number of dims (Prevent wrong input) 84 | assert initialization.ndim == len(affiliation_shape), ( 85 | initialization.shape, 86 | affiliation_shape, 87 | ) 88 | 89 | # Allow singleton dimensions to be broadcasted 90 | assert initialization.shape[-2:] == affiliation_shape[-2:], ( 91 | initialization.shape, 92 | affiliation_shape, 93 | ) 94 | 95 | affiliation = cp.broadcast_to(initialization, affiliation_shape) 96 | quadratic_form = cp.ones(affiliation_shape, dtype=y.real.dtype) 97 | elif isinstance(initialization, CACGMM): 98 | # weight[-2] may be 1, when weight is fixed to 1/K 99 | # num_classes = initialization.weight.shape[-2] 100 | num_classes = initialization.cacg.covariance_eigenvectors.shape[-3] 101 | 102 | model = initialization 103 | else: 104 | raise TypeError("No sufficient initialization.") 105 | 106 | if source_activity_mask is not None: 107 | assert ( 108 | source_activity_mask.dtype == cp.bool_ 109 | ), source_activity_mask.dtype # noqa 110 | assert source_activity_mask.shape[-2:] == (num_classes, num_observations), ( 111 | source_activity_mask.shape, 112 | independent, 113 | num_classes, 114 | num_observations, 115 | ) # noqa 116 | 117 | if isinstance(initialization, cp.ndarray): 118 | assert source_activity_mask.shape == initialization.shape, ( 119 | source_activity_mask.shape, 120 | initialization.shape, 121 | ) # noqa 122 | 123 | assert num_classes < 20, f"num_classes: {num_classes}, sure?" 124 | assert D < 35, f"Channels: {D}, sure?" 
125 | 126 | for iteration in range(iterations): 127 | if model is not None: 128 | affiliation, quadratic_form, _ = model._predict( 129 | y, 130 | source_activity_mask=source_activity_mask, 131 | affiliation_eps=affiliation_eps, 132 | ) 133 | 134 | model = self._m_step( 135 | y, 136 | quadratic_form, 137 | affiliation=affiliation, 138 | saliency=saliency, 139 | hermitize=hermitize, 140 | covariance_norm=covariance_norm, 141 | eigenvalue_floor=eigenvalue_floor, 142 | weight_constant_axis=weight_constant_axis, 143 | ) 144 | 145 | return model 146 | 147 | def fit_predict( 148 | self, 149 | y, 150 | initialization=None, 151 | num_classes=None, 152 | iterations=100, 153 | *, 154 | saliency=None, 155 | source_activity_mask=None, 156 | weight_constant_axis=(-1,), 157 | hermitize=True, 158 | covariance_norm="eigenvalue", 159 | affiliation_eps=1e-10, 160 | eigenvalue_floor=1e-10, 161 | ): 162 | """Fit a model. Then just return the posterior affiliations.""" 163 | model = self.fit( 164 | y=y, 165 | initialization=initialization, 166 | num_classes=num_classes, 167 | iterations=iterations, 168 | saliency=saliency, 169 | source_activity_mask=source_activity_mask, 170 | weight_constant_axis=weight_constant_axis, 171 | hermitize=hermitize, 172 | covariance_norm=covariance_norm, 173 | affiliation_eps=affiliation_eps, 174 | eigenvalue_floor=eigenvalue_floor, 175 | ) 176 | return model.predict(y) 177 | 178 | def _m_step( 179 | self, 180 | x, 181 | quadratic_form, 182 | affiliation, 183 | saliency, 184 | hermitize, 185 | covariance_norm, 186 | eigenvalue_floor, 187 | weight_constant_axis, 188 | ): 189 | weight = estimate_mixture_weight( 190 | affiliation=affiliation, 191 | saliency=saliency, 192 | weight_constant_axis=weight_constant_axis, 193 | ) 194 | 195 | if saliency is None: 196 | masked_affiliation = affiliation 197 | else: 198 | masked_affiliation = affiliation * saliency[..., None, :] 199 | 200 | cacg = ComplexAngularCentralGaussianTrainer()._fit( 201 | y=x[..., None, :, :], 202 | saliency=masked_affiliation, 203 | quadratic_form=quadratic_form, 204 | hermitize=hermitize, 205 | covariance_norm=covariance_norm, 206 | eigenvalue_floor=eigenvalue_floor, 207 | ) 208 | return CACGMM(weight=weight, cacg=cacg) 209 | -------------------------------------------------------------------------------- /gss/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict, namedtuple 2 | from math import ceil 3 | from pathlib import Path 4 | from typing import Any, Dict, List, Tuple 5 | 6 | import numpy as np 7 | from cytoolz.itertoolz import groupby 8 | from lhotse import CutSet, validate 9 | from lhotse.cut import Cut, MixedCut 10 | from lhotse.dataset.sampling.dynamic import DynamicCutSampler 11 | from lhotse.dataset.sampling.dynamic_bucketing import DynamicBucketingSampler 12 | from lhotse.dataset.sampling.round_robin import RoundRobinSampler 13 | from lhotse.utils import add_durations, compute_num_samples 14 | from torch.utils.data import Dataset 15 | 16 | from gss.utils.numpy_utils import segment_axis 17 | 18 | 19 | class GssDataset(Dataset): 20 | """ 21 | It takes a batch of cuts as input (all from the same recording and speaker) and 22 | concatenates them into a single sequence. Additionally, we also extend the left 23 | and right cuts by the context duration, so that the model can see the context 24 | and disambiguate the target speaker from background noise. 25 | Returns: 26 | .. 
code-block:: 27 | { 28 | 'audio': (channels x total #samples) float tensor 29 | 'activity': (#speakers x total #samples) int tensor denoting speaker activities 30 | 'cuts': original cuts (sorted by start time) 31 | 'speaker': str, speaker ID 32 | 'recording': str, recording ID 33 | 'start': float tensor, start times of the cuts w.r.t. concatenated sequence 34 | } 35 | In the returned tensor, the ``audio`` and ``activity`` will be used to perform the 36 | actual enhancement. The ``speaker``, ``recording``, and ``start`` are 37 | used to name the enhanced files. 38 | """ 39 | 40 | def __init__( 41 | self, activity, context_duration: float = 0, num_channels: int = None 42 | ) -> None: 43 | super().__init__() 44 | self.activity = activity 45 | self.context_duration = context_duration 46 | self.num_channels = num_channels 47 | 48 | def __getitem__(self, cuts: CutSet) -> Dict[str, Any]: 49 | self._validate(cuts) 50 | 51 | recording_id = cuts[0].recording_id 52 | speaker = cuts[0].supervisions[0].speaker 53 | 54 | # sort cuts by start time 55 | orig_cuts = sorted(cuts, key=lambda cut: cut.start) 56 | 57 | new_cuts = orig_cuts[:] 58 | 59 | # Extend the first and last cuts by the context duration. 60 | new_cuts[0] = new_cuts[0].extend_by( 61 | duration=self.context_duration, 62 | direction="left", 63 | preserve_id=True, 64 | pad_silence=False, 65 | ) 66 | left_context = orig_cuts[0].start - new_cuts[0].start 67 | new_cuts[-1] = new_cuts[-1].extend_by( 68 | duration=self.context_duration, 69 | direction="right", 70 | preserve_id=True, 71 | pad_silence=False, 72 | ) 73 | right_context = new_cuts[-1].end - orig_cuts[-1].end 74 | 75 | concatenated = None 76 | activity = [] 77 | for new_cut in new_cuts: 78 | concatenated = ( 79 | new_cut 80 | if concatenated is None 81 | else concatenated.append(new_cut, preserve_id="left") 82 | ) 83 | cut_activity, spk_to_idx_map = self.activity.get_activity( 84 | new_cut.recording_id, new_cut.start, new_cut.duration 85 | ) 86 | activity.append(cut_activity) 87 | 88 | # Load audio 89 | audio = concatenated.load_audio() 90 | activity = np.concatenate(activity, axis=1) 91 | 92 | return { 93 | "audio": audio, 94 | "duration": add_durations( 95 | *[c.duration for c in orig_cuts], 96 | sampling_rate=concatenated.sampling_rate, 97 | ), 98 | "left_context": compute_num_samples( 99 | left_context, sampling_rate=concatenated.sampling_rate 100 | ), 101 | "right_context": compute_num_samples( 102 | right_context, sampling_rate=concatenated.sampling_rate 103 | ), 104 | "activity": activity, 105 | "orig_cuts": orig_cuts, 106 | "speaker": speaker, 107 | "speaker_idx": spk_to_idx_map[speaker], 108 | "recording_id": recording_id, 109 | } 110 | 111 | def _validate(self, cuts: CutSet) -> None: 112 | validate(cuts) 113 | assert all(cut.has_recording for cut in cuts) 114 | assert len(cuts) > 0 115 | 116 | # check that all cuts have the same speaker and recording 117 | speaker = cuts[0].supervisions[0].speaker 118 | recording = cuts[0].recording_id 119 | assert all(cut.supervisions[0].speaker == speaker for cut in cuts) 120 | assert all(cut.recording_id == recording for cut in cuts) 121 | 122 | 123 | def create_sampler( 124 | cuts: CutSet, max_duration: float = None, max_cuts: int = None, num_buckets: int = 1 125 | ) -> RoundRobinSampler: 126 | buckets = create_buckets_by_speaker(cuts) 127 | samplers = [] 128 | for bucket in buckets: 129 | num_buckets = min(num_buckets, len(frozenset(bucket.ids))) 130 | if num_buckets == 1: 131 | samplers.append( 132 | DynamicCutSampler(bucket, 
max_duration=max_duration, max_cuts=max_cuts) 133 | ) 134 | else: 135 | samplers.append( 136 | DynamicBucketingSampler( 137 | bucket, 138 | num_buckets=num_buckets, 139 | max_duration=max_duration, 140 | max_cuts=max_cuts, 141 | ) 142 | ) 143 | sampler = RoundRobinSampler(*samplers) 144 | return sampler 145 | 146 | 147 | def create_buckets_by_speaker(cuts: CutSet) -> List[CutSet]: 148 | """ 149 | Helper method to partition a single CutSet into buckets that have the same 150 | recording and speaker. 151 | """ 152 | buckets: Dict[Tuple[str, str], List[Cut]] = defaultdict(list) 153 | for cut in cuts: 154 | buckets[(cut.recording_id, cut.supervisions[0].speaker)].append(cut) 155 | return [CutSet.from_cuts(cuts) for cuts in buckets.values()] 156 | 157 | 158 | # Taken from: https://github.com/fgnt/nara_wpe/blob/452b95beb27afad3f8fa3e378de2803452906f1b/nara_wpe/utils.py#L203 159 | def _samples_to_stft_frames( 160 | samples, 161 | size, 162 | shift, 163 | *, 164 | pad=True, 165 | fading=False, 166 | ): 167 | """ 168 | Calculates number of STFT frames from number of samples in time domain. 169 | 170 | Args: 171 | samples: Number of samples in time domain. 172 | size: FFT size. 173 | window_length often equal to FFT size. The name size should be 174 | marked as deprecated and replaced with window_length. 175 | shift: Hop in samples. 176 | pad: See stft. 177 | fading: See stft. Note to keep old behavior, default value is False. 178 | 179 | Returns: 180 | Number of STFT frames. 181 | """ 182 | if fading: 183 | samples = samples + 2 * (size - shift) 184 | 185 | # I changed this from np.ceil to math.ceil, to yield an integer result. 186 | frames = (samples - size + shift) / shift 187 | if pad: 188 | return ceil(frames) 189 | return int(frames) 190 | 191 | 192 | def start_end_context_frames( 193 | start_context_samples, end_context_samples, stft_size, stft_shift, stft_fading 194 | ): 195 | assert start_context_samples >= 0 196 | assert end_context_samples >= 0 197 | 198 | start_context_frames = _samples_to_stft_frames( 199 | start_context_samples, 200 | size=stft_size, 201 | shift=stft_shift, 202 | fading=stft_fading, 203 | ) 204 | end_context_frames = _samples_to_stft_frames( 205 | end_context_samples, 206 | size=stft_size, 207 | shift=stft_shift, 208 | fading=stft_fading, 209 | ) 210 | return start_context_frames, end_context_frames 211 | 212 | 213 | def activity_time_to_frequency( 214 | time_activity, 215 | stft_window_length, 216 | stft_shift, 217 | stft_fading, 218 | stft_pad=True, 219 | ): 220 | assert np.asarray(time_activity).dtype != object, ( 221 | type(time_activity), 222 | np.asarray(time_activity).dtype, 223 | ) 224 | time_activity = np.asarray(time_activity) 225 | 226 | if stft_fading: 227 | pad_width = np.array([(0, 0)] * time_activity.ndim) 228 | pad_width[-1, :] = stft_window_length - stft_shift # Consider fading 229 | time_activity = np.pad(time_activity, pad_width, mode="constant") 230 | 231 | return segment_axis( 232 | time_activity, 233 | length=stft_window_length, 234 | shift=stft_shift, 235 | end="pad" if stft_pad else "cut", 236 | ).any(axis=-1) 237 | 238 | 239 | EnhancedCut = namedtuple( 240 | "EnhancedCut", ["cut", "recording_id", "speaker", "start", "end"] 241 | ) 242 | 243 | 244 | def post_process_manifests(cuts, enhanced_dir): 245 | """ 246 | Post-process the enhanced cuts to combine the ones that were created from the same 247 | segment (split due to cut_into_windows). 
248 | """ 249 | enhanced_dir = Path(enhanced_dir) 250 | 251 | def _get_cut_info(cut): 252 | reco_id, spk, start_end = cut.recording_id.split("-") 253 | start, end = start_end.split("_") 254 | return reco_id, spk, float(start) / 100, float(end) / 100 255 | 256 | enhanced_cuts = [] 257 | for cut in cuts: 258 | reco_id, spk, start, end = _get_cut_info(cut) 259 | enhanced_cuts.append(EnhancedCut(cut, reco_id, spk, start, end)) 260 | 261 | # group cuts by recording id and speaker 262 | enhanced_cuts = sorted(enhanced_cuts, key=lambda x: (x.recording_id, x.speaker)) 263 | groups = groupby(lambda x: (x.recording_id, x.speaker), enhanced_cuts) 264 | 265 | combined_cuts = [] 266 | wavs_to_be_removed = [] 267 | # combine cuts that were created from the same segment 268 | for (reco_id, spk), in_cuts in groups.items(): 269 | in_cuts = sorted(in_cuts, key=lambda x: x.start) 270 | out_cut = in_cuts[0] 271 | for cut in in_cuts[1:]: 272 | if cut.start == out_cut.end: 273 | out_cut = EnhancedCut( 274 | cut=out_cut.cut.append(cut.cut), 275 | recording_id=reco_id, 276 | speaker=spk, 277 | start=out_cut.start, 278 | end=cut.end, 279 | ) 280 | # Delete the wav file of the cut that was appended (otherwise we will 281 | # have repeated audio) 282 | wavs_to_be_removed.append(cut.cut.recording.sources[0].source) 283 | else: 284 | combined_cuts.append(out_cut) 285 | out_cut = cut 286 | combined_cuts.append(out_cut) 287 | 288 | # write the combined cuts to the enhanced manifest 289 | out_cuts = [] 290 | for cut in combined_cuts: 291 | out_cut = cut.cut 292 | if isinstance(out_cut, MixedCut): 293 | out_cut = out_cut.save_audio( 294 | (enhanced_dir / cut.recording_id) 295 | / f"{cut.recording_id}-{cut.speaker}-{int(cut.start*100):06d}_{int(cut.end*100):06d}.flac" 296 | ) 297 | out_cuts.append(out_cut) 298 | 299 | # remove the wav files of the cuts that were appended 300 | for wav in wavs_to_be_removed: 301 | Path(wav).unlink() 302 | 303 | return CutSet.from_cuts(out_cuts) 304 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# GPU-accelerated Guided Source Separation
2 | 3 | Paper: https://arxiv.org/abs/2212.05271 4 | 5 | **Guided source separation** is a type of blind source separation (blind = no training required) 6 | in which the mask estimation is guided by a diarizer output. The original method was proposed 7 | for the CHiME-5 challenge in [this paper](http://spandh.dcs.shef.ac.uk/chime_workshop/papers/CHiME_2018_paper_boeddecker.pdf) by Boeddeker et al. 8 | 9 | It is a kind of target-speaker extraction method. The inputs to the model are: 10 | 11 | 1. A multi-channel recording, e.g., from an array microphone, of a long, unsegmented, 12 | multi-talker session (possibly with overlapping speech) 13 | 2. An RTTM file containing speaker segment boundaries 14 | 15 | The system produces enhanced audio for each of the segments in the RTTM, removing the background 16 | speech and noise and "extracting" only the target speaker in the segment. 17 | 18 | This repository contains a GPU implementation of this method in Python, along with CLI binaries 19 | to run the enhancement from shell. We also provide several example "recipes" for using the 20 | method. 21 | 22 | ## Features 23 | 24 | The core components of the tool are borrowed from [ `pb_chime5` ](https://github.com/fgnt/pb_chime5), but GPU support is added by porting most of the work to [CuPy](https://github.com/cupy/cupy). 25 | 26 | * All the main components of the pipeline --- STFT computation, WPE, mask estimation with CACGMM, and beamforming --- 27 | are ported to CuPy to use GPUs. For CACGMM, we batch all frequency indices instead of iterating over them. 28 | * We have implemented batch processing of segments (see [this issue](https://github.com/desh2608/gss/issues/12) for details) 29 | to maximize GPU memory usage and provide additional speed-up. 30 | * The GSS implementation (see `gss/core`) has been stripped of CHiME-6 dataset-specific peculiarities 31 | (such as array naming conventions etc.) 32 | * We use Lhotse for simplified data loading, speaker activity generation, and RTTM representation. We provide 33 | examples in the `recipes` directory for how to use the `gss` module for several datasets. 34 | * The inference can be done on multi-node GPU environment. This makes it several times faster than the 35 | original CPU implementation. 36 | * We provide both Python modules and CLI for using the enhancement functions, which can be 37 | easily included in recipes from Kaldi, Icefall, ESPNet, etc. 38 | 39 | As an example, applying GSS on a LibriCSS OV20 session (~10min) took ~160s on a single RTX2080 GPU (with 12G memory). 40 | See the `test.pstats` for the profiling output. 41 | 42 | ## Installation 43 | 44 | ### Preparing to install 45 | 46 | Create a new Conda environment: 47 | 48 | ```bash 49 | conda create -n gss python=3.8 50 | ``` 51 | 52 | Install CuPy as follows (see https://docs.cupy.dev/en/stable/install.html for the appropriate version 53 | for your CUDA). 54 | 55 | ```bash 56 | pip install cupy-cuda102 57 | ``` 58 | 59 | NOTE 1: We recommend not installing the pre-release version (12.0.0rc1 at the time of writing), since there may be some issues with it. 
60 | 61 | NOTE 2: If you don't have cudatoolkit 10.2 installed, you can use conda, which will install it for you: 62 | 63 | ```bash 64 | conda install -c conda-forge cupy=10.2 65 | ``` 66 | 67 | ### Install (basic) 68 | 69 | ```bash 70 | pip install git+https://github.com/desh2608/gss 71 | ``` 72 | 73 | ### Install (advanced) 74 | 75 | ```bash 76 | git clone https://github.com/desh2608/gss.git && cd gss 77 | pip install -e '.[dev]' 78 | pre-commit install # installs pre-commit hooks with style checks 79 | ``` 80 | 81 | ## Usage 82 | 83 | ### Enhancing a single recording 84 | 85 | For the simple case of target-speaker extraction given a multi-channel recording and an 86 | RTTM file denoting speaker segments, run the following: 87 | 88 | ```bash 89 | export CUDA_VISIBLE_DEVICES=0 90 | gss enhance recording \ 91 | /path/to/sessionA.wav /path/to/rttm exp/enhanced_segs \ 92 | --recording-id sessionA --min-segment-length 0.1 --max-segment-length 10.0 \ 93 | --max-batch-duration 20.0 --num-buckets 2 -o exp/segments.jsonl.gz 94 | ``` 95 | 96 | ### Enhancing a corpus 97 | 98 | See the `recipes` directory for usage examples. The main stages are as follows: 99 | 100 | 1. Prepare Lhotse manifests. See [this list](https://lhotse.readthedocs.io/en/latest/corpus.html#standard-data-preparation-recipes) of corpora currently supported in Lhotse. 101 | You can also apply GSS to your own dataset by preparing it as Lhotse manifests. 102 | 103 | 2. If you are using an RTTM file to get segments (e.g. in CHiME-6 Track 2), convert the RTTMs 104 | to a Lhotse-style supervision manifest. 105 | 106 | 3. Create recording-level cut sets by combining the recording with its supervisions. These 107 | will be used to get speaker activities. 108 | 109 | 4. Trim the recording-level cut set into segment-level cuts. These are the segments that will 110 | actually be enhanced. 111 | 112 | 5. (Optional) Split the segments into as many parts as the number of GPU jobs you want to run. In the 113 | recipes, we submit the jobs through `qsub`, similar to Kaldi or ESPNet recipes. You can 114 | use the parallelization utilities in those toolkits to run with a different scheduler such as 115 | SLURM. 116 | 117 | 6. Run the enhancement on GPUs. The following options can be provided: 118 | 119 | * `--channels`: The channels to use for enhancement (comma-separated ints). By default, all channels are used. 120 | 121 | * `--bss-iterations`: Number of iterations of the CACGMM inference. 122 | 123 | * `--context-duration`: Context (in seconds) to include on both sides of the segment. 124 | 125 | * `--min-segment-length`: Any segment shorter than this value will be removed. This is 126 | particularly useful when using segments from a diarizer output since they often contain 127 | very small segments which are not relevant for ASR. A recommended setting is 0.1s. 128 | 129 | * `--max-segment-length`: Segments longer than this value will be chunked up. This is 130 | to prevent OOM errors since the segment STFTs are loaded onto the GPU. We use a setting 131 | of 15s in most cases. 132 | 133 | * `--max-batch-duration`: Segments from the same speaker will be batched together to increase 134 | GPU efficiency. We used 20s batches for enhancement on GPUs with 12G memory. For GPUs with 135 | larger memory, this value can be increased. 136 | 137 | * `--max-batch-cuts`: This sets an upper limit on the number of cuts in a batch. To 138 | simulate segment-wise enhancement, set this to 1.
139 | 140 | * `--num-workers`: Number of workers to use for data loading (default = 1). Use more workers if you 141 | increase `--max-batch-duration`. 142 | 143 | * `--num-buckets`: Number of buckets to use for sampling. Batches are drawn from the same 144 | bucket (see Lhotse's [`DynamicBucketingSampler`](https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/sampling/dynamic_bucketing.py) for details). 145 | 146 | * `--enhanced-manifest/-o`: Path where the enhanced cut manifest will be written. This 147 | is useful when the supervisions need to be propagated to the enhanced segments, 148 | e.g., for downstream ASR tasks. 149 | 150 | * `--profiler-output`: Optional path to an output stats file for profiling, which can be visualized 151 | using SnakeViz. 152 | 153 | * `--force-overwrite`: Flag to force enhanced audio files to be overwritten. 154 | 155 | ### Multi-GPU Usage 156 | You can refer to e.g. the [AMI recipe](./recipes/ami/run.sh) for how to use this toolkit 157 | with multiple GPUs.
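In essence, the recipes split the per-segment cuts into one part per GPU job and run one `gss enhance cuts` process per part. A minimal Python sketch of the same idea is shown below; the manifest paths, output directory, number of GPUs, and option values are placeholders that you should adapt to your setup:

```python
import os
import subprocess

from lhotse import load_manifest

# Placeholder paths -- replace with your own manifests and output directory.
cuts_per_segment = load_manifest("exp/cuts_per_segment.jsonl.gz")
num_gpus = 2

procs = []
for gpu, split in enumerate(cuts_per_segment.split(num_gpus)):
    split_path = f"exp/cuts_per_segment.{gpu}.jsonl.gz"
    split.to_file(split_path)
    procs.append(
        subprocess.Popen(
            [
                "gss", "enhance", "cuts",
                "exp/cuts.jsonl.gz", split_path, "exp/enhanced",
                "--bss-iterations", "10",
                "--max-batch-duration", "20.0",
            ],
            # Pin each job to a different GPU.
            env={**os.environ, "CUDA_VISIBLE_DEVICES": str(gpu)},
        )
    )
for p in procs:
    p.wait()
```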
158 | **NOTE**: your GPUs must be in Exclusive_Thread compute mode; otherwise, this library may not work as expected and/or the inference 159 | time will greatly increase. **This is especially important if you are using** `run.pl`.
160 | You can check the compute mode of GPU `X` using: 161 | ```bash 162 | nvidia-smi -i X -q | grep "Compute Mode" 163 | ``` 164 | We also provide an automated tool for this called `gpu_check`, which takes as arguments the number of jobs and the cmd used (e.g. `run.pl`): 165 | ```bash 166 | $cmd JOB=1:$nj ${exp_dir}/${dset_name}/${dset_part}/log/enhance.JOB.log \ 167 | gss utils gpu_check $nj $cmd \& gss enhance cuts \ 168 | ${exp_dir}/${dset_name}/${dset_part}/cuts.jsonl.gz ${exp_dir}/${dset_name}/${dset_part}/split$nj/cuts_per_segment.JOB.jsonl.gz \ 169 | ${exp_dir}/${dset_name}/${dset_part}/enhanced \ 170 | --bss-iterations $gss_iterations \ 171 | --context-duration 15.0 \ 172 | --use-garbage-class \ 173 | --max-batch-duration 120 \ 174 | ${affix} || exit 1 175 | ``` 176 | See again the [AMI recipe](./recipes/ami/run.sh) or the [CHiME-7 DASR GSS code](https://github.com/espnet/espnet/blob/master/egs2/chime7_task1/asr1/local/run_gss.sh). 177 | ## FAQ 178 | 179 | **What happens if I set `--max-batch-duration` too large?** 180 | 181 | The enhancement would still work, but you will see several warnings of the sort: 182 | "Out of memory error while processing the batch. Trying again with chunks." 183 | Internally, we have a fallback option to chunk up batches into increasingly smaller 184 | parts in case an OOM error is encountered (see `gss/core/enhancer.py`). However, this 185 | would slow down processing, so we recommend reducing the batch size if you see this 186 | warning very frequently. 187 | 188 | **I am seeing "out of memory error" a lot. What should I do?** 189 | 190 | Try reducing `--max-batch-duration`. If you are enhancing a large number of very small 191 | segments, try providing `--max-batch-cuts` with some small value (e.g., 2 or 3). This 192 | is because batching together a large number of small segments incurs memory 193 | overhead which can cause OOMs. 194 | 195 | **How should I interpret the output file names?** 196 | 197 | The enhanced audio files are named *recoid-spkid-start_end.flac*, i.e., one file is 198 | generated for each segment in the RTTM. The "start" and "end" are in hundredths of a second, padded to 6 digits; 199 | for example, 21.18 seconds is encoded as `002118`. This convention should be fine if 200 | your audio duration is under ~2.75 h (9999s); otherwise, you should change the 201 | padding in `gss/core/enhancer.py`. 202 | 203 | **How to solve the Lhotse AudioDurationMismatch error?** 204 | 205 | This error is raised when the audio files corresponding to different channels have 206 | different durations. This is often the case for multi-array recordings, e.g., CHiME-6. 207 | You can bypass this error by setting the `--duration-tolerance` option to some larger 208 | value (Lhotse's default is 0.025 seconds). For CHiME-6, we had to set this to 3.0. 209 | 210 | **How should I generate the RTTMs required for enhancement?** 211 | 212 | For examples of how to generate RTTMs for guiding the separation, please refer to my 213 | [diarizer](https://github.com/desh2608/diarizer) toolkit. 214 | 215 | **How can I experiment with additional GSS parameters?** 216 | 217 | We have only made the most important parameters available in the 218 | top-level CLI. To play with other parameters, check out the `gss.core.enhancer.get_enhancer()` function. 219 | 220 | **How much speed-up can I expect to obtain?** 221 | 222 | Enhancing the CHiME-6 dev set required 1.3 hours on 4 GPUs. This is as opposed to the 223 | original implementation, which required 20 hours using 80 CPU jobs.
This is an effective 224 | speed-up of 292. 225 | 226 | ## Contributing 227 | 228 | Contributions for core improvements or new recipes are welcome. Please run the following 229 | before creating a pull request. 230 | 231 | ```bash 232 | pre-commit install 233 | pre-commit run # Running linter checks 234 | ``` 235 | 236 | ## Citations 237 | 238 | ``` 239 | @inproceedings{Raj2023GPUacceleratedGS, 240 | title={GPU-accelerated Guided Source Separation for Meeting Transcription}, 241 | author={Desh Raj and Daniel Povey and Sanjeev Khudanpur}, 242 | year={2023}, 243 | booktitle={InterSpeech} 244 | } 245 | ``` 246 | -------------------------------------------------------------------------------- /gss/bin/modes/enhance.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import logging 3 | import time 4 | from pathlib import Path 5 | 6 | import click 7 | from lhotse import Recording, SupervisionSet, load_manifest_lazy 8 | from lhotse.audio import set_audio_duration_mismatch_tolerance 9 | from lhotse.cut import CutSet 10 | from lhotse.utils import fastcopy 11 | 12 | from gss.bin.modes.cli_base import cli 13 | from gss.core.enhancer import get_enhancer 14 | from gss.utils.data_utils import post_process_manifests 15 | 16 | logging.basicConfig( 17 | format="%(asctime)s,%(msecs)d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s", 18 | datefmt="%Y-%m-%d:%H:%M:%S", 19 | level=logging.INFO, 20 | ) 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | @cli.group() 25 | def enhance(): 26 | """Commands for enhancing single recordings or manifests.""" 27 | pass 28 | 29 | 30 | def common_options(func): 31 | @click.option( 32 | "--channels", 33 | "-c", 34 | type=str, 35 | default=None, 36 | help="Channels to use for enhancement. Specify with comma-separated values, e.g. " 37 | "`--channels 0,2,4`. 
All channels will be used by default.", 38 | ) 39 | @click.option( 40 | "--bss-iterations", 41 | "-i", 42 | type=int, 43 | default=10, 44 | help="Number of iterations for BSS", 45 | show_default=True, 46 | ) 47 | @click.option( 48 | "--use-wpe/--no-wpe", 49 | default=True, 50 | help="Whether to use WPE for GSS", 51 | show_default=True, 52 | ) 53 | @click.option( 54 | "--context-duration", 55 | type=float, 56 | default=15.0, 57 | help="Context duration in seconds for CACGMM", 58 | show_default=True, 59 | ) 60 | @click.option( 61 | "--use-garbage-class/--no-garbage-class", 62 | default=False, 63 | help="Whether to use the additional noise class for CACGMM", 64 | show_default=True, 65 | ) 66 | @click.option( 67 | "--min-segment-length", 68 | type=float, 69 | default=0.0, 70 | help="Minimum segment length to retain (removing very small segments speeds up enhancement)", 71 | show_default=True, 72 | ) 73 | @click.option( 74 | "--max-segment-length", 75 | type=float, 76 | default=15.0, 77 | help="Chunk up longer segments to avoid OOM issues", 78 | show_default=True, 79 | ) 80 | @click.option( 81 | "--max-batch-duration", 82 | type=float, 83 | default=20.0, 84 | help="Maximum duration of a batch in seconds", 85 | show_default=True, 86 | ) 87 | @click.option( 88 | "--max-batch-cuts", 89 | type=int, 90 | default=None, 91 | help="Maximum number of cuts in a batch", 92 | show_default=True, 93 | ) 94 | @click.option( 95 | "--num-workers", 96 | type=int, 97 | default=1, 98 | help="Number of workers for parallel processing", 99 | show_default=True, 100 | ) 101 | @click.option( 102 | "--num-buckets", 103 | type=int, 104 | default=2, 105 | help="Number of buckets per speaker for batching (use larger values if you set higer max-segment-length)", 106 | show_default=True, 107 | ) 108 | @click.option( 109 | "--enhanced-manifest", 110 | "-o", 111 | type=click.Path(), 112 | default=None, 113 | help="Path to the output manifest containing details of the enhanced segments.", 114 | ) 115 | @click.option( 116 | "--profiler-output", 117 | type=click.Path(), 118 | default=None, 119 | help="Path to the profiler output file.", 120 | ) 121 | @click.option( 122 | "--force-overwrite", 123 | is_flag=True, 124 | default=False, 125 | help="If set, we will overwrite the enhanced audio files if they already exist.", 126 | ) 127 | @functools.wraps(func) 128 | def wrapper(*args, **kwargs): 129 | return func(*args, **kwargs) 130 | 131 | return wrapper 132 | 133 | 134 | @enhance.command(name="cuts") 135 | @click.argument( 136 | "cuts_per_recording", 137 | type=click.Path(exists=True), 138 | ) 139 | @click.argument( 140 | "cuts_per_segment", 141 | type=click.Path(exists=True), 142 | ) 143 | @click.argument( 144 | "enhanced_dir", 145 | type=click.Path(), 146 | ) 147 | @common_options 148 | @click.option( 149 | "--duration-tolerance", 150 | type=float, 151 | default=None, 152 | help="Maximum mismatch between channel durations to allow. Some corpora like CHiME-6 " 153 | "need a large value, e.g., 2 seconds", 154 | ) 155 | def cuts_( 156 | cuts_per_recording, 157 | cuts_per_segment, 158 | enhanced_dir, 159 | channels, 160 | bss_iterations, 161 | use_wpe, 162 | context_duration, 163 | use_garbage_class, 164 | min_segment_length, 165 | max_segment_length, 166 | max_batch_duration, 167 | max_batch_cuts, 168 | num_workers, 169 | num_buckets, 170 | enhanced_manifest, 171 | profiler_output, 172 | force_overwrite, 173 | duration_tolerance, 174 | ): 175 | """ 176 | Enhance segments (represented by cuts). 
177 | 178 | CUTS_PER_RECORDING: Lhotse cuts manifest containing cuts per recording 179 | CUTS_PER_SEGMENT: Lhotse cuts manifest containing cuts per segment (e.g. obtained using `trim-to-supervisions`) 180 | ENHANCED_DIR: Output directory for enhanced audio files 181 | """ 182 | if profiler_output is not None: 183 | import atexit 184 | import cProfile 185 | import pstats 186 | 187 | print("Profiling...") 188 | pr = cProfile.Profile() 189 | pr.enable() 190 | 191 | def exit(): 192 | pr.disable() 193 | print("Profiling completed") 194 | pstats.Stats(pr).sort_stats("cumulative").dump_stats(profiler_output) 195 | 196 | atexit.register(exit) 197 | 198 | if duration_tolerance is not None: 199 | set_audio_duration_mismatch_tolerance(duration_tolerance) 200 | 201 | enhanced_dir = Path(enhanced_dir) 202 | enhanced_dir.mkdir(exist_ok=True, parents=True) 203 | 204 | cuts = load_manifest_lazy(cuts_per_recording) 205 | cuts_per_segment = load_manifest_lazy(cuts_per_segment) 206 | 207 | if channels is not None: 208 | channels = [int(c) for c in channels.split(",")] 209 | cuts_per_segment = CutSet.from_cuts( 210 | fastcopy(cut, channel=channels) for cut in cuts_per_segment 211 | ) 212 | 213 | # Paranoia mode: ensure that cuts_per_recording have ids same as the recording_id 214 | cuts = CutSet.from_cuts(cut.with_id(cut.recording_id) for cut in cuts) 215 | 216 | logger.info("Aplying min/max segment length constraints") 217 | cuts_per_segment = cuts_per_segment.filter( 218 | lambda c: c.duration > min_segment_length 219 | ).cut_into_windows(duration=max_segment_length) 220 | 221 | logger.info("Initializing GSS enhancer") 222 | enhancer = get_enhancer( 223 | cuts=cuts, 224 | bss_iterations=bss_iterations, 225 | context_duration=context_duration, 226 | activity_garbage_class=use_garbage_class, 227 | wpe=use_wpe, 228 | ) 229 | 230 | logger.info(f"Enhancing {len(frozenset(c.id for c in cuts_per_segment))} segments") 231 | begin = time.time() 232 | num_errors, out_cuts = enhancer.enhance_cuts( 233 | cuts_per_segment, 234 | enhanced_dir, 235 | max_batch_duration=max_batch_duration, 236 | max_batch_cuts=max_batch_cuts, 237 | num_workers=num_workers, 238 | num_buckets=num_buckets, 239 | force_overwrite=force_overwrite, 240 | ) 241 | end = time.time() 242 | logger.info(f"Finished in {end-begin:.2f}s with {num_errors} errors") 243 | 244 | if enhanced_manifest is not None: 245 | logger.info(f"Saving enhanced cuts manifest to {enhanced_manifest}") 246 | out_cuts = post_process_manifests(out_cuts, enhanced_dir) 247 | out_cuts.to_file(enhanced_manifest) 248 | 249 | 250 | @enhance.command(name="recording") 251 | @click.argument( 252 | "recording", 253 | type=click.Path(exists=True), 254 | ) 255 | @click.argument( 256 | "rttm", 257 | type=click.Path(exists=True), 258 | ) 259 | @click.argument( 260 | "enhanced_dir", 261 | type=click.Path(), 262 | ) 263 | @click.option( 264 | "--recording-id", 265 | type=str, 266 | default=None, 267 | help="Name of recording (will be used to get corresponding segments from RTTM)", 268 | ) 269 | @common_options 270 | def recording_( 271 | recording, 272 | rttm, 273 | enhanced_dir, 274 | recording_id, 275 | channels, 276 | bss_iterations, 277 | use_wpe, 278 | context_duration, 279 | use_garbage_class, 280 | min_segment_length, 281 | max_segment_length, 282 | max_batch_duration, 283 | max_batch_cuts, 284 | num_workers, 285 | num_buckets, 286 | enhanced_manifest, 287 | profiler_output, 288 | force_overwrite, 289 | ): 290 | """ 291 | Enhance a single recording using an RTTM file. 
292 | 293 | RECORDING: Path to a multi-channel recording 294 | RTTM: Path to an RTTM file containing speech activity 295 | ENHANCED_DIR: Output directory for enhanced audio files 296 | """ 297 | if profiler_output is not None: 298 | import atexit 299 | import cProfile 300 | import pstats 301 | 302 | print("Profiling...") 303 | pr = cProfile.Profile() 304 | pr.enable() 305 | 306 | def exit(): 307 | pr.disable() 308 | print("Profiling completed") 309 | pstats.Stats(pr).sort_stats("cumulative").dump_stats(profiler_output) 310 | 311 | atexit.register(exit) 312 | 313 | enhanced_dir = Path(enhanced_dir) 314 | enhanced_dir.mkdir(exist_ok=True, parents=True) 315 | 316 | cut = Recording.from_file(recording, recording_id=recording_id).to_cut() 317 | if channels is not None: 318 | channels = [int(c) for c in channels.split(",")] 319 | cut = fastcopy(cut, channel=channels) 320 | 321 | supervisions = SupervisionSet.from_rttm(rttm).filter( 322 | lambda s: s.recording_id == cut.id 323 | ) 324 | # Modify channel IDs to match the recording 325 | supervisions = SupervisionSet.from_segments( 326 | fastcopy(s, channel=cut.channel) for s in supervisions 327 | ) 328 | cut.supervisions = supervisions 329 | 330 | # Create a cuts manifest with a single cut for the recording 331 | cuts = CutSet.from_cuts([cut]) 332 | 333 | # Create segment-wise cuts 334 | cuts_per_segment = cuts.trim_to_supervisions( 335 | keep_overlapping=False, keep_all_channels=True 336 | ) 337 | 338 | logger.info("Aplying min/max segment length constraints") 339 | cuts_per_segment = cuts_per_segment.filter( 340 | lambda c: c.duration > min_segment_length 341 | ).cut_into_windows(duration=max_segment_length) 342 | 343 | logger.info("Initializing GSS enhancer") 344 | enhancer = get_enhancer( 345 | cuts=cuts, 346 | bss_iterations=bss_iterations, 347 | context_duration=context_duration, 348 | activity_garbage_class=use_garbage_class, 349 | wpe=use_wpe, 350 | ) 351 | 352 | logger.info(f"Enhancing {len(frozenset(c.id for c in cuts_per_segment))} segments") 353 | begin = time.time() 354 | num_errors, out_cuts = enhancer.enhance_cuts( 355 | cuts_per_segment, 356 | enhanced_dir, 357 | max_batch_duration=max_batch_duration, 358 | max_batch_cuts=max_batch_cuts, 359 | num_workers=num_workers, 360 | num_buckets=num_buckets, 361 | force_overwrite=force_overwrite, 362 | ) 363 | end = time.time() 364 | logger.info(f"Finished in {end-begin:.2f}s with {num_errors} errors") 365 | 366 | if enhanced_manifest is not None: 367 | logger.info(f"Saving enhanced cuts manifest to {enhanced_manifest}") 368 | out_cuts = post_process_manifests(out_cuts, enhanced_dir) 369 | out_cuts.to_file(enhanced_manifest) 370 | -------------------------------------------------------------------------------- /gss/core/enhancer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from concurrent.futures import ThreadPoolExecutor 3 | from dataclasses import dataclass 4 | from pathlib import Path 5 | from types import SimpleNamespace 6 | 7 | import cupy as cp 8 | import numpy as np 9 | import soundfile as sf 10 | from lhotse import CutSet, Recording, RecordingSet, SupervisionSegment, SupervisionSet 11 | from lhotse.utils import add_durations, compute_num_samples 12 | from torch.utils.data import DataLoader 13 | 14 | from gss.core import GSS, WPE, Activity, Beamformer 15 | from gss.utils.data_utils import ( 16 | GssDataset, 17 | activity_time_to_frequency, 18 | create_sampler, 19 | start_end_context_frames, 20 | ) 21 | 22 | 
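# The enhancement pipeline implemented below is, roughly:
#   STFT -> (optional) WPE dereverberation -> CACGMM mask estimation (GSS)
#   -> mask-based beamforming -> (optional) postfilter -> iSTFT
# with the heavy lifting done on the GPU via CuPy arrays.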
logging.basicConfig( 23 | format="%(asctime)s,%(msecs)d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s", 24 | datefmt="%Y-%m-%d:%H:%M:%S", 25 | level=logging.INFO, 26 | ) 27 | 28 | 29 | def get_enhancer( 30 | cuts, 31 | context_duration=15, # 15 seconds 32 | wpe=True, 33 | wpe_tabs=10, 34 | wpe_delay=2, 35 | wpe_iterations=3, 36 | wpe_psd_context=0, 37 | activity_garbage_class=True, 38 | stft_size=1024, 39 | stft_shift=256, 40 | stft_fading=True, 41 | bss_iterations=20, 42 | bss_iterations_post=1, 43 | bf_drop_context=True, 44 | postfilter=None, 45 | ): 46 | assert wpe is True or wpe is False, wpe 47 | assert len(cuts) > 0 48 | 49 | sampling_rate = cuts[0].recording.sampling_rate 50 | 51 | return Enhancer( 52 | context_duration=context_duration, 53 | wpe_block=WPE( 54 | taps=wpe_tabs, 55 | delay=wpe_delay, 56 | iterations=wpe_iterations, 57 | psd_context=wpe_psd_context, 58 | ) 59 | if wpe 60 | else None, 61 | activity=Activity( 62 | garbage_class=activity_garbage_class, 63 | cuts=cuts, 64 | ), 65 | gss_block=GSS( 66 | iterations=bss_iterations, 67 | iterations_post=bss_iterations_post, 68 | ), 69 | bf_drop_context=bf_drop_context, 70 | bf_block=Beamformer( 71 | postfilter=postfilter, 72 | ), 73 | stft_size=stft_size, 74 | stft_shift=stft_shift, 75 | stft_fading=stft_fading, 76 | sampling_rate=sampling_rate, 77 | ) 78 | 79 | 80 | @dataclass 81 | class Enhancer: 82 | """ 83 | This class creates enhancement context (with speaker activity) for the sessions, and 84 | performs the enhancement. 85 | """ 86 | 87 | wpe_block: WPE 88 | activity: Activity 89 | gss_block: GSS 90 | bf_block: Beamformer 91 | 92 | bf_drop_context: bool 93 | 94 | stft_size: int 95 | stft_shift: int 96 | stft_fading: bool 97 | 98 | context_duration: float # e.g. 15 99 | sampling_rate: int 100 | 101 | def stft(self, x): 102 | from gss.core.stft_module import stft 103 | 104 | return stft( 105 | x, 106 | size=self.stft_size, 107 | shift=self.stft_shift, 108 | fading=self.stft_fading, 109 | ) 110 | 111 | def istft(self, X): 112 | from gss.core.stft_module import istft 113 | 114 | return istft( 115 | X, 116 | size=self.stft_size, 117 | shift=self.stft_shift, 118 | fading=self.stft_fading, 119 | ) 120 | 121 | def enhance_cuts( 122 | self, 123 | cuts, 124 | exp_dir, 125 | max_batch_duration=None, 126 | max_batch_cuts=None, 127 | num_buckets=2, 128 | num_workers=1, 129 | force_overwrite=False, 130 | ): 131 | """ 132 | Enhance the given CutSet. 
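        Arguments:
            cuts: CutSet of segment-level cuts to enhance.
            exp_dir: output directory where the enhanced FLAC files are written.
            max_batch_duration: maximum total duration (in seconds) of a batch.
            max_batch_cuts: optional upper limit on the number of cuts in a batch.
            num_buckets: number of duration buckets used by the sampler.
            num_workers: number of dataloader workers.
            force_overwrite: if True, overwrite enhanced files that already exist.

        Returns a tuple of (number of errors, CutSet of enhanced cuts).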
133 | """ 134 | num_error = 0 135 | out_cuts = [] # list of enhanced cuts 136 | 137 | # Create the dataset, sampler, and data loader 138 | gss_dataset = GssDataset( 139 | context_duration=self.context_duration, activity=self.activity 140 | ) 141 | gss_sampler = create_sampler( 142 | cuts, 143 | max_duration=max_batch_duration, 144 | max_cuts=max_batch_cuts, 145 | num_buckets=num_buckets, 146 | ) 147 | dl = DataLoader( 148 | gss_dataset, 149 | sampler=gss_sampler, 150 | batch_size=None, 151 | num_workers=num_workers, 152 | persistent_workers=False, 153 | ) 154 | 155 | def _save_worker(orig_cuts, x_hat, recording_id, speaker): 156 | out_dir = exp_dir / recording_id 157 | enhanced_recordings = [] 158 | enhanced_supervisions = [] 159 | offset = 0 160 | for cut in orig_cuts: 161 | save_path = Path( 162 | f"{recording_id}-{speaker}-{round(100*cut.start):06d}_{round(100*cut.end):06d}.flac" 163 | ) 164 | if force_overwrite or not (out_dir / save_path).exists(): 165 | st = compute_num_samples(offset, self.sampling_rate) 166 | en = st + compute_num_samples(cut.duration, self.sampling_rate) 167 | x_hat_cut = x_hat[:, st:en] 168 | logging.debug("Saving enhanced signal") 169 | sf.write( 170 | file=str(out_dir / save_path), 171 | data=x_hat_cut.transpose(), 172 | samplerate=self.sampling_rate, 173 | format="FLAC", 174 | ) 175 | # Update offset for the next cut 176 | offset = add_durations( 177 | offset, cut.duration, sampling_rate=self.sampling_rate 178 | ) 179 | else: 180 | logging.info(f"File {save_path} already exists. Skipping.") 181 | # add enhanced recording to list 182 | enhanced_recordings.append(Recording.from_file(out_dir / save_path)) 183 | # modify supervision channels since enhanced recording has only 1 channel 184 | enhanced_supervisions.extend( 185 | [ 186 | SupervisionSegment( 187 | id=str(save_path), 188 | recording_id=str(save_path), 189 | start=segment.start, 190 | duration=segment.duration, 191 | channel=0, 192 | text=segment.text, 193 | language=segment.language, 194 | speaker=segment.speaker, 195 | ) 196 | for segment in cut.supervisions 197 | ] 198 | ) 199 | return enhanced_recordings, enhanced_supervisions 200 | 201 | # Iterate over batches 202 | futures = [] 203 | total_processed = 0 204 | with ThreadPoolExecutor(max_workers=num_workers) as executor: 205 | for batch_idx, batch in enumerate(dl): 206 | batch = SimpleNamespace(**batch) 207 | logging.info( 208 | f"Processing batch {batch_idx+1} {batch.recording_id, batch.speaker}: " 209 | f"{len(batch.orig_cuts)} segments = {batch.duration}s (tot processed: {total_processed} segments)" 210 | ) 211 | total_processed += len(batch.orig_cuts) 212 | 213 | out_dir = exp_dir / batch.recording_id 214 | out_dir.mkdir(parents=True, exist_ok=True) 215 | 216 | file_exists = [] 217 | if not force_overwrite: 218 | for cut in batch.orig_cuts: 219 | save_path = Path( 220 | f"{batch.recording_id}-{batch.speaker}-{round(100*cut.start):06d}_{round(100*cut.end):06d}.flac" 221 | ) 222 | file_exists.append((out_dir / save_path).exists()) 223 | 224 | if all(file_exists): 225 | logging.info("All files already exist. Skipping.") 226 | continue 227 | 228 | # Sometimes the segment may be large and cause OOM issues in CuPy. If this 229 | # happens, we increasingly chunk it up into smaller segments until it can 230 | # be processed without breaking. 
231 | num_chunks = 1 232 | max_chunks = self.stft_size // 2 + 1 233 | while num_chunks <= max_chunks: 234 | try: 235 | x_hat = self.enhance_batch( 236 | batch.audio, 237 | batch.activity, 238 | batch.speaker_idx, 239 | num_chunks=num_chunks, 240 | left_context=batch.left_context, 241 | right_context=batch.right_context, 242 | ) 243 | break 244 | except cp.cuda.memory.OutOfMemoryError: 245 | num_chunks = num_chunks + 1 246 | if num_chunks <= max_chunks: 247 | logging.warning( 248 | f"Out of memory error while processing the batch. Trying again with {num_chunks} chunks." 249 | ) 250 | except Exception as e: 251 | logging.error(f"Error enhancing batch: {e}") 252 | num_error += 1 253 | # Keep the original signal (only load channel 0) 254 | # NOTE (@desh2608): One possible issue here is that the whole batch 255 | # may fail even if the issue is only due to one segment. We may 256 | # want to handle this case separately. 257 | x_hat = batch.audio[0:1].cpu().numpy() 258 | break 259 | if num_chunks > max_chunks: 260 | # OOM error 261 | logging.error( 262 | f"Out of memory error while processing the batch. " 263 | f"Reached the maximum number of chunks, exiting." 264 | f"Please reduce --max-batch-duration." 265 | ) 266 | raise cp.cuda.memory.OutOfMemoryError 267 | 268 | # Save the enhanced cut to disk 269 | futures.append( 270 | executor.submit( 271 | _save_worker, 272 | batch.orig_cuts, 273 | x_hat, 274 | batch.recording_id, 275 | batch.speaker, 276 | ) 277 | ) 278 | 279 | out_recordings = [] 280 | out_supervisions = [] 281 | for future in futures: 282 | enhanced_recordings, enhanced_supervisions = future.result() 283 | out_recordings.extend(enhanced_recordings) 284 | out_supervisions.extend(enhanced_supervisions) 285 | 286 | out_recordings = RecordingSet.from_recordings(out_recordings) 287 | out_supervisions = SupervisionSet.from_segments(out_supervisions) 288 | return num_error, CutSet.from_manifests( 289 | recordings=out_recordings, supervisions=out_supervisions 290 | ) 291 | 292 | def enhance_batch( 293 | self, obs, activity, speaker_id, num_chunks=1, left_context=0, right_context=0 294 | ): 295 | logging.debug(f"Converting activity to frequency domain") 296 | activity_freq = activity_time_to_frequency( 297 | activity, 298 | stft_window_length=self.stft_size, 299 | stft_shift=self.stft_shift, 300 | stft_fading=self.stft_fading, 301 | stft_pad=True, 302 | ) 303 | 304 | # Convert to cupy array (putting it on the GPU) 305 | obs = cp.asarray(obs) 306 | 307 | logging.debug(f"Computing STFT") 308 | Obs = self.stft(obs) 309 | 310 | D, T, F = Obs.shape 311 | 312 | # Process observation in chunks 313 | # Use freq axis as suggested by Christoph Boedekker 314 | # see https://github.com/desh2608/gss/issues/33 315 | chunk_size = int(np.ceil(F / num_chunks)) 316 | masks = [] 317 | for i in range(num_chunks): 318 | st = i * chunk_size 319 | en = min(F, (i + 1) * chunk_size) 320 | Obs_chunk = Obs[:, :, st:en] 321 | 322 | logging.debug(f"Applying WPE") 323 | if self.wpe_block is not None: 324 | Obs_chunk = self.wpe_block(Obs_chunk) 325 | # Replace the chunk in the original array (to save memory) 326 | Obs[:, :, st:en] = Obs_chunk 327 | 328 | logging.debug(f"Computing GSS masks") 329 | masks_chunk = self.gss_block(Obs_chunk, activity_freq) 330 | masks.append(masks_chunk) 331 | 332 | masks = cp.concatenate(masks, axis=-1) # concat along freq 333 | if self.bf_drop_context: 334 | logging.debug("Dropping context for beamforming") 335 | left_context_frames, right_context_frames = start_end_context_frames( 336 | 
left_context, 337 | right_context, 338 | stft_size=self.stft_size, 339 | stft_shift=self.stft_shift, 340 | stft_fading=self.stft_fading, 341 | ) 342 | logging.debug( 343 | f"left_context_frames: {left_context_frames}, right_context_frames: {right_context_frames}" 344 | ) 345 | 346 | masks[:, :left_context_frames, :] = 0 347 | if right_context_frames > 0: 348 | masks[:, -right_context_frames:, :] = 0 349 | 350 | target_mask = masks[speaker_id] 351 | distortion_mask = cp.sum(masks, axis=0) - target_mask 352 | 353 | logging.debug("Applying beamforming with computed masks") 354 | X_hat = [] 355 | for i in range(num_chunks): 356 | st = i * chunk_size 357 | en = min(F, (i + 1) * chunk_size) 358 | X_hat_chunk = self.bf_block( 359 | Obs[:, :, st:en], 360 | target_mask=target_mask[:, st:en], 361 | distortion_mask=distortion_mask[:, st:en], 362 | ) 363 | X_hat.append(X_hat_chunk) 364 | 365 | X_hat = cp.concatenate(X_hat, axis=1) # freq axis again 366 | 367 | logging.debug("Computing inverse STFT") 368 | x_hat = self.istft(X_hat) # returns a numpy array 369 | 370 | if x_hat.ndim == 1: 371 | x_hat = x_hat[np.newaxis, :] 372 | 373 | # Trim x_hat to original length of cut 374 | if right_context > 0: 375 | x_hat = x_hat[:, left_context:-right_context] 376 | elif right_context == 0: 377 | x_hat = x_hat[:, left_context:] 378 | else: 379 | logging.warning( 380 | f"Right context is less than zero. Only left context is used." 381 | ) 382 | x_hat = x_hat[:, left_context:] 383 | 384 | return x_hat 385 | -------------------------------------------------------------------------------- /recipes/libricss/utils/queue.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | use strict; 3 | use warnings; 4 | 5 | # Copyright 2012 Johns Hopkins University (Author: Daniel Povey). 6 | # 2014 Vimal Manohar (Johns Hopkins University) 7 | # Apache 2.0. 8 | 9 | use File::Basename; 10 | use Cwd; 11 | use Getopt::Long; 12 | 13 | # queue.pl has the same functionality as run.pl, except that 14 | # it runs the job in question on the queue (Sun GridEngine). 15 | # This version of queue.pl uses the task array functionality 16 | # of the grid engine. Note: it's different from the queue.pl 17 | # in the s4 and earlier scripts. 18 | 19 | # The script now supports configuring the queue system using a config file 20 | # (default in conf/queue.conf; but can be passed specified with --config option) 21 | # and a set of command line options. 22 | # The current script handles: 23 | # 1) Normal configuration arguments 24 | # For e.g. a command line option of "--gpu 1" could be converted into the option 25 | # "-q g.q -l gpu=1" to qsub. How the CLI option is handled is determined by a 26 | # line in the config file like 27 | # gpu=* -q g.q -l gpu=$0 28 | # $0 here in the line is replaced with the argument read from the CLI and the 29 | # resulting string is passed to qsub. 30 | # 2) Special arguments to options such as 31 | # gpu=0 32 | # If --gpu 0 is given in the command line, then no special "-q" is given. 33 | # 3) Default argument 34 | # default gpu=0 35 | # If --gpu option is not passed in the command line, then the script behaves as 36 | # if --gpu 0 was passed since 0 is specified as the default argument for that 37 | # option 38 | # 4) Arbitrary options and arguments. 39 | # Any command line option starting with '--' and its argument would be handled 40 | # as long as its defined in the config file. 
41 | # 5) Default behavior 42 | # If the config file that is passed using is not readable, then the script 43 | # behaves as if the queue has the following config file: 44 | # $ cat conf/queue.conf 45 | # # Default configuration 46 | # command qsub -v PATH -cwd -S /bin/bash -j y -l arch=*64* 47 | # option mem=* -l mem_free=$0,ram_free=$0 48 | # option mem=0 # Do not add anything to qsub_opts 49 | # option num_threads=* -pe smp $0 50 | # option num_threads=1 # Do not add anything to qsub_opts 51 | # option max_jobs_run=* -tc $0 52 | # default gpu=0 53 | # option gpu=0 -q all.q 54 | # option gpu=* -l gpu=$0 -q g.q 55 | 56 | my $qsub_opts = ""; 57 | my $sync = 0; 58 | my $num_threads = 1; 59 | my $gpu = 0; 60 | 61 | my $config = "conf/queue.conf"; 62 | 63 | my %cli_options = (); 64 | 65 | my $jobname; 66 | my $jobstart; 67 | my $jobend; 68 | my $array_job = 0; 69 | my $sge_job_id; 70 | 71 | sub print_usage() { 72 | print STDERR 73 | "Usage: queue.pl [options] [JOB=1:n] log-file command-line arguments...\n" . 74 | "e.g.: queue.pl foo.log echo baz\n" . 75 | " (which will echo \"baz\", with stdout and stderr directed to foo.log)\n" . 76 | "or: queue.pl -q all.q\@xyz foo.log echo bar \| sed s/bar/baz/ \n" . 77 | " (which is an example of using a pipe; you can provide other escaped bash constructs)\n" . 78 | "or: queue.pl -q all.q\@qyz JOB=1:10 foo.JOB.log echo JOB \n" . 79 | " (which illustrates the mechanism to submit parallel jobs; note, you can use \n" . 80 | " another string other than JOB)\n" . 81 | "Note: if you pass the \"-sync y\" option to qsub, this script will take note\n" . 82 | "and change its behavior. Otherwise it uses qstat to work out when the job finished\n" . 83 | "Options:\n" . 84 | " --config (default: $config)\n" . 85 | " --mem (e.g. --mem 2G, --mem 500M, \n" . 86 | " also support K and numbers mean bytes)\n" . 87 | " --num-threads (default: $num_threads)\n" . 88 | " --max-jobs-run \n" . 89 | " --gpu <0|1> (default: $gpu)\n"; 90 | exit 1; 91 | } 92 | 93 | sub caught_signal { 94 | if ( defined $sge_job_id ) { # Signal trapped after submitting jobs 95 | my $signal = $!; 96 | system ("qdel $sge_job_id"); 97 | print STDERR "Caught a signal: $signal , deleting SGE task: $sge_job_id and exiting\n"; 98 | exit(2); 99 | } 100 | } 101 | 102 | if (@ARGV < 2) { 103 | print_usage(); 104 | } 105 | 106 | for (my $x = 1; $x <= 2; $x++) { # This for-loop is to 107 | # allow the JOB=1:n option to be interleaved with the 108 | # options to qsub. 109 | while (@ARGV >= 2 && $ARGV[0] =~ m:^-:) { 110 | my $switch = shift @ARGV; 111 | 112 | if ($switch eq "-V") { 113 | $qsub_opts .= "-V "; 114 | } else { 115 | my $argument = shift @ARGV; 116 | if ($argument =~ m/^--/) { 117 | print STDERR "WARNING: suspicious argument '$argument' to $switch; starts with '-'\n"; 118 | } 119 | if ($switch eq "-sync" && $argument =~ m/^[yY]/) { 120 | $sync = 1; 121 | $qsub_opts .= "$switch $argument "; 122 | } elsif ($switch eq "-pe") { # e.g. -pe smp 5 123 | my $argument2 = shift @ARGV; 124 | $qsub_opts .= "$switch $argument $argument2 "; 125 | $num_threads = $argument2; 126 | } elsif ($switch =~ m/^--/) { # Config options 127 | # Convert CLI option to variable name 128 | # by removing '--' from the switch and replacing any 129 | # '-' with a '_' 130 | $switch =~ s/^--//; 131 | $switch =~ s/-/_/g; 132 | $cli_options{$switch} = $argument; 133 | } else { # Other qsub options - passed as is 134 | $qsub_opts .= "$switch $argument "; 135 | } 136 | } 137 | } 138 | if ($ARGV[0] =~ m/^([\w_][\w\d_]*)+=(\d+):(\d+)$/) { # e.g. 
JOB=1:20 139 | $array_job = 1; 140 | $jobname = $1; 141 | $jobstart = $2; 142 | $jobend = $3; 143 | shift; 144 | if ($jobstart > $jobend) { 145 | die "queue.pl: invalid job range $ARGV[0]"; 146 | } 147 | if ($jobstart <= 0) { 148 | die "run.pl: invalid job range $ARGV[0], start must be strictly positive (this is a GridEngine limitation)."; 149 | } 150 | } elsif ($ARGV[0] =~ m/^([\w_][\w\d_]*)+=(\d+)$/) { # e.g. JOB=1. 151 | $array_job = 1; 152 | $jobname = $1; 153 | $jobstart = $2; 154 | $jobend = $2; 155 | shift; 156 | } elsif ($ARGV[0] =~ m/.+\=.*\:.*$/) { 157 | print STDERR "queue.pl: Warning: suspicious first argument to queue.pl: $ARGV[0]\n"; 158 | } 159 | } 160 | 161 | if (@ARGV < 2) { 162 | print_usage(); 163 | } 164 | 165 | if (exists $cli_options{"config"}) { 166 | $config = $cli_options{"config"}; 167 | } 168 | 169 | my $default_config_file = <<'EOF'; 170 | # Default configuration 171 | command qsub -v PATH -cwd -S /bin/bash -j y -l arch=*64* 172 | option mem=* -l mem_free=$0,ram_free=$0 173 | option mem=0 # Do not add anything to qsub_opts 174 | option num_threads=* -pe smp $0 175 | option num_threads=1 # Do not add anything to qsub_opts 176 | option max_jobs_run=* -tc $0 177 | default gpu=0 178 | option gpu=0 179 | option gpu=* -l gpu=$0 -q '*.q' 180 | EOF 181 | 182 | # Here the configuration options specified by the user on the command line 183 | # (e.g. --mem 2G) are converted to options to the qsub system as defined in 184 | # the config file. (e.g. if the config file has the line 185 | # "option mem=* -l ram_free=$0,mem_free=$0" 186 | # and the user has specified '--mem 2G' on the command line, the options 187 | # passed to queue system would be "-l ram_free=2G,mem_free=2G 188 | # A more detailed description of the ways the options would be handled is at 189 | # the top of this file. 190 | 191 | $SIG{INT} = \&caught_signal; 192 | $SIG{TERM} = \&caught_signal; 193 | 194 | my $opened_config_file = 1; 195 | 196 | open CONFIG, "<$config" or $opened_config_file = 0; 197 | 198 | my %cli_config_options = (); 199 | my %cli_default_options = (); 200 | 201 | if ($opened_config_file == 0 && exists($cli_options{"config"})) { 202 | print STDERR "Could not open config file $config\n"; 203 | exit(1); 204 | } elsif ($opened_config_file == 0 && !exists($cli_options{"config"})) { 205 | # Open the default config file instead 206 | open (CONFIG, "echo '$default_config_file' |") or die "Unable to open pipe\n"; 207 | $config = "Default config"; 208 | } 209 | 210 | my $qsub_cmd = ""; 211 | my $read_command = 0; 212 | 213 | while() { 214 | chomp; 215 | my $line = $_; 216 | $_ =~ s/\s*#.*//g; 217 | if ($_ eq "") { next; } 218 | if ($_ =~ /^command (.+)/) { 219 | $read_command = 1; 220 | $qsub_cmd = $1 . " "; 221 | } elsif ($_ =~ m/^option ([^=]+)=\* (.+)$/) { 222 | # Config option that needs replacement with parameter value read from CLI 223 | # e.g.: option mem=* -l mem_free=$0,ram_free=$0 224 | my $option = $1; # mem 225 | my $arg= $2; # -l mem_free=$0,ram_free=$0 226 | if ($arg !~ m:\$0:) { 227 | die "Unable to parse line '$line' in config file ($config)\n"; 228 | } 229 | if (exists $cli_options{$option}) { 230 | # Replace $0 with the argument read from command line. 231 | # e.g. "-l mem_free=$0,ram_free=$0" -> "-l mem_free=2G,ram_free=2G" 232 | $arg =~ s/\$0/$cli_options{$option}/g; 233 | $cli_config_options{$option} = $arg; 234 | } 235 | } elsif ($_ =~ m/^option ([^=]+)=(\S+)\s?(.*)$/) { 236 | # Config option that does not need replacement 237 | # e.g. 
option gpu=0 -q all.q 238 | my $option = $1; # gpu 239 | my $value = $2; # 0 240 | my $arg = $3; # -q all.q 241 | if (exists $cli_options{$option}) { 242 | $cli_default_options{($option,$value)} = $arg; 243 | } 244 | } elsif ($_ =~ m/^default (\S+)=(\S+)/) { 245 | # Default options. Used for setting default values to options i.e. when 246 | # the user does not specify the option on the command line 247 | # e.g. default gpu=0 248 | my $option = $1; # gpu 249 | my $value = $2; # 0 250 | if (!exists $cli_options{$option}) { 251 | # If the user has specified this option on the command line, then we 252 | # don't have to do anything 253 | $cli_options{$option} = $value; 254 | } 255 | } else { 256 | print STDERR "queue.pl: unable to parse line '$line' in config file ($config)\n"; 257 | exit(1); 258 | } 259 | } 260 | 261 | close(CONFIG); 262 | 263 | if ($read_command != 1) { 264 | print STDERR "queue.pl: config file ($config) does not contain the line \"command .*\"\n"; 265 | exit(1); 266 | } 267 | 268 | for my $option (keys %cli_options) { 269 | if ($option eq "config") { next; } 270 | if ($option eq "max_jobs_run" && $array_job != 1) { next; } 271 | my $value = $cli_options{$option}; 272 | 273 | if (exists $cli_default_options{($option,$value)}) { 274 | $qsub_opts .= "$cli_default_options{($option,$value)} "; 275 | } elsif (exists $cli_config_options{$option}) { 276 | $qsub_opts .= "$cli_config_options{$option} "; 277 | } else { 278 | if ($opened_config_file == 0) { $config = "default config file"; } 279 | die "queue.pl: Command line option $option not described in $config (or value '$value' not allowed)\n"; 280 | } 281 | } 282 | 283 | my $cwd = getcwd(); 284 | my $logfile = shift @ARGV; 285 | 286 | if ($array_job == 1 && $logfile !~ m/$jobname/ 287 | && $jobend > $jobstart) { 288 | print STDERR "queue.pl: you are trying to run a parallel job but " 289 | . "you are putting the output into just one log file ($logfile)\n"; 290 | exit(1); 291 | } 292 | 293 | # 294 | # Work out the command; quote escaping is done here. 295 | # Note: the rules for escaping stuff are worked out pretty 296 | # arbitrarily, based on what we want it to do. Some things that 297 | # we pass as arguments to queue.pl, such as "|", we want to be 298 | # interpreted by bash, so we don't escape them. Other things, 299 | # such as archive specifiers like 'ark:gunzip -c foo.gz|', we want 300 | # to be passed, in quotes, to the Kaldi program. Our heuristic 301 | # is that stuff with spaces in should be quoted. This doesn't 302 | # always work. 303 | # 304 | my $cmd = ""; 305 | 306 | foreach my $x (@ARGV) { 307 | if ($x =~ m/^\S+$/) { $cmd .= $x . " "; } # If string contains no spaces, take 308 | # as-is. 309 | elsif ($x =~ m:\":) { $cmd .= "'$x' "; } # else if no dbl-quotes, use single 310 | else { $cmd .= "\"$x\" "; } # else use double. 311 | } 312 | 313 | # 314 | # Work out the location of the script file, and open it for writing. 315 | # 316 | my $dir = dirname($logfile); 317 | my $base = basename($logfile); 318 | my $qdir = "$dir/q"; 319 | $qdir =~ s:/(log|LOG)/*q:/q:; # If qdir ends in .../log/q, make it just .../q. 320 | my $queue_logfile = "$qdir/$base"; 321 | 322 | if (!-d $dir) { system "mkdir -p $dir 2>/dev/null"; } # another job may be doing this... 323 | if (!-d $dir) { die "Cannot make the directory $dir\n"; } 324 | # make a directory called "q", 325 | # where we will put the log created by qsub... normally this doesn't contain 326 | # anything interesting, evertyhing goes to $logfile. 
327 | # in $qdir/sync we'll put the done.* files... we try to keep this 328 | # directory small because it's transmitted over NFS many times. 329 | if (! -d "$qdir/sync") { 330 | system "mkdir -p $qdir/sync 2>/dev/null"; 331 | sleep(5); ## This is to fix an issue we encountered in denominator lattice creation, 332 | ## where if e.g. the exp/tri2b_denlats/log/15/q directory had just been 333 | ## created and the job immediately ran, it would die with an error because nfs 334 | ## had not yet synced. I'm also decreasing the acdirmin and acdirmax in our 335 | ## NFS settings to something like 5 seconds. 336 | } 337 | 338 | my $queue_array_opt = ""; 339 | if ($array_job == 1) { # It's an array job. 340 | $queue_array_opt = "-t $jobstart:$jobend"; 341 | $logfile =~ s/$jobname/\$SGE_TASK_ID/g; # This variable will get 342 | # replaced by qsub, in each job, with the job-id. 343 | $cmd =~ s/$jobname/\$\{SGE_TASK_ID\}/g; # same for the command... 344 | $queue_logfile =~ s/\.?$jobname//; # the log file in the q/ subdirectory 345 | # is for the queue to put its log, and this doesn't need the task array subscript 346 | # so we remove it. 347 | } 348 | 349 | # queue_scriptfile is as $queue_logfile [e.g. dir/q/foo.log] but 350 | # with the suffix .sh. 351 | my $queue_scriptfile = $queue_logfile; 352 | ($queue_scriptfile =~ s/\.[a-zA-Z]{1,5}$/.sh/) || ($queue_scriptfile .= ".sh"); 353 | if ($queue_scriptfile !~ m:^/:) { 354 | $queue_scriptfile = $cwd . "/" . $queue_scriptfile; # just in case. 355 | } 356 | 357 | # We'll write to the standard input of "qsub" (the file-handle Q), 358 | # the job that we want it to execute. 359 | # Also keep our current PATH around, just in case there was something 360 | # in it that we need (although we also source ./path.sh) 361 | 362 | my $syncfile = "$qdir/sync/done.$$"; 363 | 364 | unlink($queue_logfile, $syncfile); 365 | # 366 | # Write to the script file, and then close it. 367 | # 368 | open(Q, ">$queue_scriptfile") || die "Failed to write to $queue_scriptfile"; 369 | 370 | print Q "#!/bin/bash\n"; 371 | print Q "cd $cwd\n"; 372 | print Q ". ./path.sh\n"; 373 | print Q "( echo '#' Running on \`hostname\`\n"; 374 | print Q " echo '#' Started at \`date\`\n"; 375 | print Q " echo -n '# '; cat <$logfile\n"; 379 | print Q "time1=\`date +\"%s\"\`\n"; 380 | print Q " ( $cmd ) 2>>$logfile >>$logfile\n"; 381 | print Q "ret=\$?\n"; 382 | print Q "time2=\`date +\"%s\"\`\n"; 383 | print Q "echo '#' Accounting: time=\$((\$time2-\$time1)) threads=$num_threads >>$logfile\n"; 384 | print Q "echo '#' Finished at \`date\` with status \$ret >>$logfile\n"; 385 | print Q "[ \$ret -eq 137 ] && exit 100;\n"; # If process was killed (e.g. oom) it will exit with status 137; 386 | # let the script return with status 100 which will put it to E state; more easily rerunnable. 387 | if ($array_job == 0) { # not an array job 388 | print Q "touch $syncfile\n"; # so we know it's done. 389 | } else { 390 | print Q "touch $syncfile.\$SGE_TASK_ID\n"; # touch a bunch of sync-files. 391 | } 392 | print Q "exit \$[\$ret ? 1 : 0]\n"; # avoid status 100 which grid-engine 393 | print Q "## submitted with:\n"; # treats specially. 394 | $qsub_cmd .= "-o $queue_logfile $qsub_opts $queue_array_opt $queue_scriptfile >>$queue_logfile 2>&1"; 395 | print Q "# $qsub_cmd\n"; 396 | if (!close(Q)) { # close was not successful... 
|| die "Could not close script file $shfile"; 397 | die "Failed to close the script file (full disk?)"; 398 | } 399 | chmod 0755, $queue_scriptfile; 400 | 401 | # This block submits the job to the queue. 402 | for (my $try = 1; $try < 5; $try++) { 403 | my $ret = system ($qsub_cmd); 404 | if ($ret != 0) { 405 | if ($sync && $ret == 256) { # this is the exit status when a job failed (bad exit status) 406 | if (defined $jobname) { 407 | $logfile =~ s/\$SGE_TASK_ID/*/g; 408 | } 409 | print STDERR "queue.pl: job writing to $logfile failed\n"; 410 | exit(1); 411 | } else { 412 | print STDERR "queue.pl: Error submitting jobs to queue (return status was $ret)\n"; 413 | print STDERR "queue log file is $queue_logfile, command was $qsub_cmd\n"; 414 | my $err = `tail $queue_logfile`; 415 | print STDERR "Output of qsub was: $err\n"; 416 | if ($err =~ m/gdi request/ || $err =~ m/qmaster/) { 417 | # When we get queue connectivity problems we usually see a message like: 418 | # Unable to run job: failed receiving gdi request response for mid=1 (got 419 | # syncron message receive timeout error).. 420 | my $waitfor = 20; 421 | print STDERR "queue.pl: It looks like the queue master may be inaccessible. " . 422 | " Trying again after $waitfor seconts\n"; 423 | sleep($waitfor); 424 | # ... and continue throught the loop. 425 | } else { 426 | exit(1); 427 | } 428 | } 429 | } else { 430 | last; # break from the loop. 431 | } 432 | } 433 | 434 | if (! $sync) { # We're not submitting with -sync y, so we 435 | # need to wait for the jobs to finish. We wait for the 436 | # sync-files we "touched" in the script to exist. 437 | my @syncfiles = (); 438 | if (!defined $jobname) { # not an array job. 439 | push @syncfiles, $syncfile; 440 | } else { 441 | for (my $jobid = $jobstart; $jobid <= $jobend; $jobid++) { 442 | push @syncfiles, "$syncfile.$jobid"; 443 | } 444 | } 445 | # We will need the sge_job_id, to check that job still exists 446 | { # This block extracts the numeric SGE job-id from the log file in q/. 447 | # It may be used later to query 'qstat' about the job. 448 | open(L, "<$queue_logfile") || die "Error opening log file $queue_logfile"; 449 | undef $sge_job_id; 450 | while () { 451 | if (m/Your job\S* (\d+)[. ].+ has been submitted/) { 452 | if (defined $sge_job_id) { 453 | die "Error: your job was submitted more than once (see $queue_logfile)"; 454 | } else { 455 | $sge_job_id = $1; 456 | } 457 | } 458 | } 459 | close(L); 460 | if (!defined $sge_job_id) { 461 | die "Error: log file $queue_logfile does not specify the SGE job-id."; 462 | } 463 | } 464 | my $check_sge_job_ctr=1; 465 | 466 | my $wait = 0.1; 467 | my $counter = 0; 468 | foreach my $f (@syncfiles) { 469 | # wait for the jobs to finish one by one. 470 | while (! -f $f) { 471 | sleep($wait); 472 | $wait *= 1.2; 473 | if ($wait > 3.0) { 474 | $wait = 3.0; # never wait more than 3 seconds. 475 | # the following (.kick) commands are basically workarounds for NFS bugs. 476 | if (rand() < 0.25) { # don't do this every time... 477 | if (rand() > 0.5) { 478 | system("touch $qdir/sync/.kick"); 479 | } else { 480 | unlink("$qdir/sync/.kick"); 481 | } 482 | } 483 | if ($counter++ % 10 == 0) { 484 | # This seems to kick NFS in the teeth to cause it to refresh the 485 | # directory. I've seen cases where it would indefinitely fail to get 486 | # updated, even though the file exists on the server. 487 | # Only do this every 10 waits (every 30 seconds) though, or if there 488 | # are many jobs waiting they can overwhelm the file server. 
489 | system("ls $qdir/sync >/dev/null"); 490 | } 491 | } 492 | 493 | # The purpose of the next block is so that queue.pl can exit if the job 494 | # was killed without terminating. It's a bit complicated because (a) we 495 | # don't want to overload the qmaster by querying it too frequently), and 496 | # (b) sometimes the qmaster is unreachable or temporarily down, and we 497 | # don't want this to necessarily kill the job. 498 | if (($check_sge_job_ctr < 100 && ($check_sge_job_ctr++ % 10) == 0) || 499 | ($check_sge_job_ctr >= 100 && ($check_sge_job_ctr++ % 50) == 0)) { 500 | # Don't run qstat too often, avoid stress on SGE; the if-condition above 501 | # is designed to check every 10 waits at first, and eventually every 50 502 | # waits. 503 | if ( -f $f ) { next; } #syncfile appeared: OK. 504 | my $output = `qstat -j $sge_job_id 2>&1`; 505 | my $ret = $?; 506 | if ($ret >> 8 == 1 && $output !~ m/qmaster/ && 507 | $output !~ m/gdi request/) { 508 | # Don't consider immediately missing job as error, first wait some 509 | # time to make sure it is not just delayed creation of the syncfile. 510 | 511 | sleep(3); 512 | # Sometimes NFS gets confused and thinks it's transmitted the directory 513 | # but it hasn't, due to timestamp issues. Changing something in the 514 | # directory will usually fix that. 515 | system("touch $qdir/sync/.kick"); 516 | unlink("$qdir/sync/.kick"); 517 | if ( -f $f ) { next; } #syncfile appeared, ok 518 | sleep(7); 519 | system("touch $qdir/sync/.kick"); 520 | sleep(1); 521 | unlink("qdir/sync/.kick"); 522 | if ( -f $f ) { next; } #syncfile appeared, ok 523 | sleep(60); 524 | system("touch $qdir/sync/.kick"); 525 | sleep(1); 526 | unlink("$qdir/sync/.kick"); 527 | if ( -f $f ) { next; } #syncfile appeared, ok 528 | $f =~ m/\.(\d+)$/ || die "Bad sync-file name $f"; 529 | my $job_id = $1; 530 | if (defined $jobname) { 531 | $logfile =~ s/\$SGE_TASK_ID/$job_id/g; 532 | } 533 | my $last_line = `tail -n 1 $logfile`; 534 | if ($last_line =~ m/status 0$/ && (-M $logfile) < 0) { 535 | # if the last line of $logfile ended with "status 0" and 536 | # $logfile is newer than this program [(-M $logfile) gives the 537 | # time elapsed between file modification and the start of this 538 | # program], then we assume the program really finished OK, 539 | # and maybe something is up with the file system. 540 | print STDERR "**queue.pl: syncfile $f was not created but job seems\n" . 541 | "**to have finished OK. Probably your file-system has problems.\n" . 542 | "**This is just a warning.\n"; 543 | last; 544 | } else { 545 | chop $last_line; 546 | print STDERR "queue.pl: Error, unfinished job no " . 547 | "longer exists, log is in $logfile, last line is '$last_line', " . 548 | "syncfile is $f, return status of qstat was $ret\n" . 549 | "Possible reasons: a) Exceeded time limit? -> Use more jobs!" . 550 | " b) Shutdown/Frozen machine? -> Run again! Qmaster output " . 551 | "was: $output\n"; 552 | exit(1); 553 | } 554 | } elsif ($ret != 0) { 555 | print STDERR "queue.pl: Warning: qstat command returned status $ret (qstat -j $sge_job_id,$!)\n"; 556 | print STDERR "queue.pl: output was: $output"; 557 | } 558 | } 559 | } 560 | } 561 | unlink(@syncfiles); 562 | } 563 | 564 | # OK, at this point we are synced; we know the job is done. 565 | # But we don't know about its exit status. We'll look at $logfile for this. 566 | # First work out an array @logfiles of file-locations we need to 567 | # read (just one, unless it's an array job). 
568 | my @logfiles = (); 569 | if (!defined $jobname) { # not an array job. 570 | push @logfiles, $logfile; 571 | } else { 572 | for (my $jobid = $jobstart; $jobid <= $jobend; $jobid++) { 573 | my $l = $logfile; 574 | $l =~ s/\$SGE_TASK_ID/$jobid/g; 575 | push @logfiles, $l; 576 | } 577 | } 578 | 579 | my $num_failed = 0; 580 | my $status = 1; 581 | foreach my $l (@logfiles) { 582 | my @wait_times = (0.1, 0.2, 0.2, 0.3, 0.5, 0.5, 1.0, 2.0, 5.0, 5.0, 5.0, 10.0, 25.0); 583 | for (my $iter = 0; $iter <= @wait_times; $iter++) { 584 | my $line = `tail -10 $l 2>/dev/null`; # Note: although this line should be the last 585 | # line of the file, I've seen cases where it was not quite the last line because 586 | # of delayed output by the process that was running, or processes it had called. 587 | # so tail -10 gives it a little leeway. 588 | if ($line =~ m/with status (\d+)/) { 589 | $status = $1; 590 | last; 591 | } else { 592 | if ($iter < @wait_times) { 593 | sleep($wait_times[$iter]); 594 | } else { 595 | if (! -f $l) { 596 | print STDERR "Log-file $l does not exist.\n"; 597 | } else { 598 | print STDERR "The last line of log-file $l does not seem to indicate the " 599 | . "return status as expected\n"; 600 | } 601 | exit(1); # Something went wrong with the queue, or the 602 | # machine it was running on, probably. 603 | } 604 | } 605 | } 606 | # OK, now we have $status, which is the return-status of 607 | # the command in the job. 608 | if ($status != 0) { $num_failed++; } 609 | } 610 | if ($num_failed == 0) { exit(0); } 611 | else { # we failed. 612 | if (@logfiles == 1) { 613 | if (defined $jobname) { $logfile =~ s/\$SGE_TASK_ID/$jobstart/g; } 614 | print STDERR "queue.pl: job failed with status $status, log is in $logfile\n"; 615 | if ($logfile =~ m/JOB/) { 616 | print STDERR "queue.pl: probably you forgot to put JOB=1:\$nj in your script.\n"; 617 | } 618 | } else { 619 | if (defined $jobname) { $logfile =~ s/\$SGE_TASK_ID/*/g; } 620 | my $numjobs = 1 + $jobend - $jobstart; 621 | print STDERR "queue.pl: $num_failed / $numjobs failed, log is in $logfile\n"; 622 | } 623 | exit(1); 624 | } 625 | -------------------------------------------------------------------------------- /recipes/libricss/utils/queue-ackgpu.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | use strict; 3 | use warnings; 4 | 5 | # Copyright 2012 Johns Hopkins University (Author: Daniel Povey). 6 | # 2014 Vimal Manohar (Johns Hopkins University) 7 | # Apache 2.0. 8 | 9 | use File::Basename; 10 | use Cwd; 11 | use Getopt::Long; 12 | 13 | # queue.pl has the same functionality as run.pl, except that 14 | # it runs the job in question on the queue (Sun GridEngine). 15 | # This version of queue.pl uses the task array functionality 16 | # of the grid engine. Note: it's different from the queue.pl 17 | # in the s4 and earlier scripts. 18 | 19 | # The script now supports configuring the queue system using a config file 20 | # (default in conf/queue.conf; but can be passed specified with --config option) 21 | # and a set of command line options. 22 | # The current script handles: 23 | # 1) Normal configuration arguments 24 | # For e.g. a command line option of "--gpu 1" could be converted into the option 25 | # "-q g.q -l gpu=1" to qsub. 
How the CLI option is handled is determined by a 26 | # line in the config file like 27 | # gpu=* -q g.q -l gpu=$0 28 | # $0 here in the line is replaced with the argument read from the CLI and the 29 | # resulting string is passed to qsub. 30 | # 2) Special arguments to options such as 31 | # gpu=0 32 | # If --gpu 0 is given in the command line, then no special "-q" is given. 33 | # 3) Default argument 34 | # default gpu=0 35 | # If --gpu option is not passed in the command line, then the script behaves as 36 | # if --gpu 0 was passed since 0 is specified as the default argument for that 37 | # option 38 | # 4) Arbitrary options and arguments. 39 | # Any command line option starting with '--' and its argument would be handled 40 | # as long as it's defined in the config file. 41 | # 5) Default behavior 42 | # If the config file that is passed using --config is not readable, then the script 43 | # behaves as if the queue has the following config file: 44 | # $ cat conf/queue.conf 45 | # # Default configuration 46 | # command qsub -v PATH -cwd -S /bin/bash -j y -l arch=*64* 47 | # option mem=* -l mem_free=$0,ram_free=$0 48 | # option mem=0 # Do not add anything to qsub_opts 49 | # option num_threads=* -pe smp $0 50 | # option num_threads=1 # Do not add anything to qsub_opts 51 | # option max_jobs_run=* -tc $0 52 | # default gpu=0 53 | # option gpu=0 -q all.q 54 | # option gpu=* -l gpu=$0 -q g.q 55 | 56 | my $qsub_opts = ""; 57 | my $sync = 0; 58 | my $num_threads = 1; 59 | my $gpu = 0; 60 | 61 | my $config = "conf/queue.conf"; 62 | 63 | my %cli_options = (); 64 | 65 | my $jobname; 66 | my $jobstart; 67 | my $jobend; 68 | my $array_job = 0; 69 | my $sge_job_id; 70 | 71 | sub print_usage() { 72 | print STDERR 73 | "Usage: queue.pl [options] [JOB=1:n] log-file command-line arguments...\n" . 74 | "e.g.: queue.pl foo.log echo baz\n" . 75 | " (which will echo \"baz\", with stdout and stderr directed to foo.log)\n" . 76 | "or: queue.pl -q all.q\@xyz foo.log echo bar \| sed s/bar/baz/ \n" . 77 | " (which is an example of using a pipe; you can provide other escaped bash constructs)\n" . 78 | "or: queue.pl -q all.q\@qyz JOB=1:10 foo.JOB.log echo JOB \n" . 79 | " (which illustrates the mechanism to submit parallel jobs; note, you can use \n" . 80 | " another string other than JOB)\n" . 81 | "Note: if you pass the \"-sync y\" option to qsub, this script will take note\n" . 82 | "and change its behavior. Otherwise it uses qstat to work out when the job finished\n" . 83 | "Options:\n" . 84 | " --config <config-file> (default: $config)\n" . 85 | " --mem <mem-requirement> (e.g. --mem 2G, --mem 500M, \n" . 86 | " also support K and numbers mean bytes)\n" . 87 | " --num-threads <num-threads> (default: $num_threads)\n" . 88 | " --max-jobs-run <num-jobs>\n" . 89 | " --gpu <0|1> (default: $gpu)\n"; 90 | exit 1; 91 | } 92 | 93 | sub caught_signal { 94 | if ( defined $sge_job_id ) { # Signal trapped after submitting jobs 95 | my $signal = $!; 96 | system ("qdel $sge_job_id"); 97 | print STDERR "Caught a signal: $signal , deleting SGE task: $sge_job_id and exiting\n"; 98 | exit(2); 99 | } 100 | } 101 | 102 | if (@ARGV < 2) { 103 | print_usage(); 104 | } 105 | 106 | for (my $x = 1; $x <= 2; $x++) { # This for-loop is to 107 | # allow the JOB=1:n option to be interleaved with the 108 | # options to qsub.
109 | while (@ARGV >= 2 && $ARGV[0] =~ m:^-:) { 110 | my $switch = shift @ARGV; 111 | 112 | if ($switch eq "-V") { 113 | $qsub_opts .= "-V "; 114 | } else { 115 | my $argument = shift @ARGV; 116 | if ($argument =~ m/^--/) { 117 | print STDERR "WARNING: suspicious argument '$argument' to $switch; starts with '-'\n"; 118 | } 119 | if ($switch eq "-sync" && $argument =~ m/^[yY]/) { 120 | $sync = 1; 121 | $qsub_opts .= "$switch $argument "; 122 | } elsif ($switch eq "-pe") { # e.g. -pe smp 5 123 | my $argument2 = shift @ARGV; 124 | $qsub_opts .= "$switch $argument $argument2 "; 125 | $num_threads = $argument2; 126 | } elsif ($switch =~ m/^--/) { # Config options 127 | # Convert CLI option to variable name 128 | # by removing '--' from the switch and replacing any 129 | # '-' with a '_' 130 | $switch =~ s/^--//; 131 | $switch =~ s/-/_/g; 132 | $cli_options{$switch} = $argument; 133 | } else { # Other qsub options - passed as is 134 | $qsub_opts .= "$switch $argument "; 135 | } 136 | } 137 | } 138 | if ($ARGV[0] =~ m/^([\w_][\w\d_]*)+=(\d+):(\d+)$/) { # e.g. JOB=1:20 139 | $array_job = 1; 140 | $jobname = $1; 141 | $jobstart = $2; 142 | $jobend = $3; 143 | shift; 144 | if ($jobstart > $jobend) { 145 | die "queue.pl: invalid job range $ARGV[0]"; 146 | } 147 | if ($jobstart <= 0) { 148 | die "run.pl: invalid job range $ARGV[0], start must be strictly positive (this is a GridEngine limitation)."; 149 | } 150 | } elsif ($ARGV[0] =~ m/^([\w_][\w\d_]*)+=(\d+)$/) { # e.g. JOB=1. 151 | $array_job = 1; 152 | $jobname = $1; 153 | $jobstart = $2; 154 | $jobend = $2; 155 | shift; 156 | } elsif ($ARGV[0] =~ m/.+\=.*\:.*$/) { 157 | print STDERR "queue.pl: Warning: suspicious first argument to queue.pl: $ARGV[0]\n"; 158 | } 159 | } 160 | 161 | if (@ARGV < 2) { 162 | print_usage(); 163 | } 164 | 165 | if (exists $cli_options{"config"}) { 166 | $config = $cli_options{"config"}; 167 | } 168 | 169 | my $default_config_file = <<'EOF'; 170 | # Default configuration 171 | command qsub -v PATH -cwd -S /bin/bash -j y -l arch=*64* 172 | option mem=* -l mem_free=$0,ram_free=$0 173 | option mem=0 # Do not add anything to qsub_opts 174 | option num_threads=* -pe smp $0 175 | option num_threads=1 # Do not add anything to qsub_opts 176 | option max_jobs_run=* -tc $0 177 | default gpu=0 178 | option gpu=0 179 | option gpu=* -l gpu=$0 -q '*.q' 180 | EOF 181 | 182 | # Here the configuration options specified by the user on the command line 183 | # (e.g. --mem 2G) are converted to options to the qsub system as defined in 184 | # the config file. (e.g. if the config file has the line 185 | # "option mem=* -l ram_free=$0,mem_free=$0" 186 | # and the user has specified '--mem 2G' on the command line, the options 187 | # passed to queue system would be "-l ram_free=2G,mem_free=2G 188 | # A more detailed description of the ways the options would be handled is at 189 | # the top of this file. 
190 | 191 | $SIG{INT} = \&caught_signal; 192 | $SIG{TERM} = \&caught_signal; 193 | 194 | my $opened_config_file = 1; 195 | 196 | open CONFIG, "<$config" or $opened_config_file = 0; 197 | 198 | my %cli_config_options = (); 199 | my %cli_default_options = (); 200 | 201 | if ($opened_config_file == 0 && exists($cli_options{"config"})) { 202 | print STDERR "Could not open config file $config\n"; 203 | exit(1); 204 | } elsif ($opened_config_file == 0 && !exists($cli_options{"config"})) { 205 | # Open the default config file instead 206 | open (CONFIG, "echo '$default_config_file' |") or die "Unable to open pipe\n"; 207 | $config = "Default config"; 208 | } 209 | 210 | my $qsub_cmd = ""; 211 | my $read_command = 0; 212 | 213 | while(<CONFIG>) { 214 | chomp; 215 | my $line = $_; 216 | $_ =~ s/\s*#.*//g; 217 | if ($_ eq "") { next; } 218 | if ($_ =~ /^command (.+)/) { 219 | $read_command = 1; 220 | $qsub_cmd = $1 . " "; 221 | } elsif ($_ =~ m/^option ([^=]+)=\* (.+)$/) { 222 | # Config option that needs replacement with parameter value read from CLI 223 | # e.g.: option mem=* -l mem_free=$0,ram_free=$0 224 | my $option = $1; # mem 225 | my $arg= $2; # -l mem_free=$0,ram_free=$0 226 | if ($arg !~ m:\$0:) { 227 | die "Unable to parse line '$line' in config file ($config)\n"; 228 | } 229 | if (exists $cli_options{$option}) { 230 | # Replace $0 with the argument read from command line. 231 | # e.g. "-l mem_free=$0,ram_free=$0" -> "-l mem_free=2G,ram_free=2G" 232 | $arg =~ s/\$0/$cli_options{$option}/g; 233 | $cli_config_options{$option} = $arg; 234 | } 235 | } elsif ($_ =~ m/^option ([^=]+)=(\S+)\s?(.*)$/) { 236 | # Config option that does not need replacement 237 | # e.g. option gpu=0 -q all.q 238 | my $option = $1; # gpu 239 | my $value = $2; # 0 240 | my $arg = $3; # -q all.q 241 | if (exists $cli_options{$option}) { 242 | $cli_default_options{($option,$value)} = $arg; 243 | } 244 | } elsif ($_ =~ m/^default (\S+)=(\S+)/) { 245 | # Default options. Used for setting default values to options i.e. when 246 | # the user does not specify the option on the command line 247 | # e.g.
default gpu=0 248 | my $option = $1; # gpu 249 | my $value = $2; # 0 250 | if (!exists $cli_options{$option}) { 251 | # If the user has specified this option on the command line, then we 252 | # don't have to do anything 253 | $cli_options{$option} = $value; 254 | } 255 | } else { 256 | print STDERR "queue.pl: unable to parse line '$line' in config file ($config)\n"; 257 | exit(1); 258 | } 259 | } 260 | 261 | close(CONFIG); 262 | 263 | if ($read_command != 1) { 264 | print STDERR "queue.pl: config file ($config) does not contain the line \"command .*\"\n"; 265 | exit(1); 266 | } 267 | 268 | for my $option (keys %cli_options) { 269 | if ($option eq "config") { next; } 270 | if ($option eq "max_jobs_run" && $array_job != 1) { next; } 271 | my $value = $cli_options{$option}; 272 | 273 | if (exists $cli_default_options{($option,$value)}) { 274 | $qsub_opts .= "$cli_default_options{($option,$value)} "; 275 | } elsif (exists $cli_config_options{$option}) { 276 | $qsub_opts .= "$cli_config_options{$option} "; 277 | } else { 278 | if ($opened_config_file == 0) { $config = "default config file"; } 279 | die "queue.pl: Command line option $option not described in $config (or value '$value' not allowed)\n"; 280 | } 281 | } 282 | 283 | my $cwd = getcwd(); 284 | my $logfile = shift @ARGV; 285 | 286 | if ($array_job == 1 && $logfile !~ m/$jobname/ 287 | && $jobend > $jobstart) { 288 | print STDERR "queue.pl: you are trying to run a parallel job but " 289 | . "you are putting the output into just one log file ($logfile)\n"; 290 | exit(1); 291 | } 292 | 293 | # 294 | # Work out the command; quote escaping is done here. 295 | # Note: the rules for escaping stuff are worked out pretty 296 | # arbitrarily, based on what we want it to do. Some things that 297 | # we pass as arguments to queue.pl, such as "|", we want to be 298 | # interpreted by bash, so we don't escape them. Other things, 299 | # such as archive specifiers like 'ark:gunzip -c foo.gz|', we want 300 | # to be passed, in quotes, to the Kaldi program. Our heuristic 301 | # is that stuff with spaces in should be quoted. This doesn't 302 | # always work. 303 | # 304 | my $cmd = ""; 305 | 306 | foreach my $x (@ARGV) { 307 | if ($x =~ m/^\S+$/) { $cmd .= $x . " "; } # If string contains no spaces, take 308 | # as-is. 309 | elsif ($x =~ m:\":) { $cmd .= "'$x' "; } # else if no dbl-quotes, use single 310 | else { $cmd .= "\"$x\" "; } # else use double. 311 | } 312 | 313 | # 314 | # Work out the location of the script file, and open it for writing. 315 | # 316 | my $dir = dirname($logfile); 317 | my $base = basename($logfile); 318 | my $qdir = "$dir/q"; 319 | $qdir =~ s:/(log|LOG)/*q:/q:; # If qdir ends in .../log/q, make it just .../q. 320 | my $queue_logfile = "$qdir/$base"; 321 | 322 | if (!-d $dir) { system "mkdir -p $dir 2>/dev/null"; } # another job may be doing this... 323 | if (!-d $dir) { die "Cannot make the directory $dir\n"; } 324 | # make a directory called "q", 325 | # where we will put the log created by qsub... normally this doesn't contain 326 | # anything interesting, evertyhing goes to $logfile. 327 | # in $qdir/sync we'll put the done.* files... we try to keep this 328 | # directory small because it's transmitted over NFS many times. 329 | if (! -d "$qdir/sync") { 330 | system "mkdir -p $qdir/sync 2>/dev/null"; 331 | sleep(5); ## This is to fix an issue we encountered in denominator lattice creation, 332 | ## where if e.g. 
the exp/tri2b_denlats/log/15/q directory had just been 333 | ## created and the job immediately ran, it would die with an error because nfs 334 | ## had not yet synced. I'm also decreasing the acdirmin and acdirmax in our 335 | ## NFS settings to something like 5 seconds. 336 | } 337 | 338 | my $queue_array_opt = ""; 339 | if ($array_job == 1) { # It's an array job. 340 | $queue_array_opt = "-t $jobstart:$jobend"; 341 | $logfile =~ s/$jobname/\$SGE_TASK_ID/g; # This variable will get 342 | # replaced by qsub, in each job, with the job-id. 343 | $cmd =~ s/$jobname/\$\{SGE_TASK_ID\}/g; # same for the command... 344 | $queue_logfile =~ s/\.?$jobname//; # the log file in the q/ subdirectory 345 | # is for the queue to put its log, and this doesn't need the task array subscript 346 | # so we remove it. 347 | } 348 | 349 | # queue_scriptfile is as $queue_logfile [e.g. dir/q/foo.log] but 350 | # with the suffix .sh. 351 | my $queue_scriptfile = $queue_logfile; 352 | ($queue_scriptfile =~ s/\.[a-zA-Z]{1,5}$/.sh/) || ($queue_scriptfile .= ".sh"); 353 | if ($queue_scriptfile !~ m:^/:) { 354 | $queue_scriptfile = $cwd . "/" . $queue_scriptfile; # just in case. 355 | } 356 | 357 | # We'll write to the standard input of "qsub" (the file-handle Q), 358 | # the job that we want it to execute. 359 | # Also keep our current PATH around, just in case there was something 360 | # in it that we need (although we also source ./path.sh) 361 | 362 | my $syncfile = "$qdir/sync/done.$$"; 363 | 364 | unlink($queue_logfile, $syncfile); 365 | # 366 | # Write to the script file, and then close it. 367 | # 368 | open(Q, ">$queue_scriptfile") || die "Failed to write to $queue_scriptfile"; 369 | 370 | print Q "#!/usr/bin/env bash\n"; 371 | print Q "cd $cwd\n"; 372 | print Q ". ./path.sh\n"; 373 | print Q "( echo '#' Running on \`hostname\`\n"; 374 | print Q " echo '#' Started at \`date\`\n"; 375 | print Q " echo -n '# '; cat <<EOF\n"; 376 | print Q "$cmd\n"; # this is a way of echoing the command into a comment in the log file, 377 | print Q "EOF\n"; # without having to escape things like "|" and quote characters. 378 | print Q ") >$logfile\n"; 379 | print Q "if ! which acquire-gpu &> /dev/null; then\n"; 380 | print Q " echo 'command not found: acquire-gpu not found.'\n"; 381 | print Q " exit 1\n"; 382 | print Q "fi\n"; 383 | print Q "for _ in \$(seq $cli_options{'gpu'}); do source acquire-gpu; done\n"; 384 | print Q "time1=\`date +\"%s\"\`\n"; 385 | print Q " ( $cmd ) 2>>$logfile >>$logfile\n"; 386 | print Q "ret=\$?\n"; 387 | print Q "time2=\`date +\"%s\"\`\n"; 388 | print Q "echo '#' Accounting: time=\$((\$time2-\$time1)) threads=$num_threads >>$logfile\n"; 389 | print Q "echo '#' Finished at \`date\` with status \$ret >>$logfile\n"; 390 | print Q "[ \$ret -eq 137 ] && exit 100;\n"; # If process was killed (e.g. oom) it will exit with status 137; 391 | # let the script return with status 100 which will put it to E state; more easily rerunnable. 392 | if ($array_job == 0) { # not an array job 393 | print Q "touch $syncfile\n"; # so we know it's done. 394 | } else { 395 | print Q "touch $syncfile.\$SGE_TASK_ID\n"; # touch a bunch of sync-files. 396 | } 397 | print Q "exit \$[\$ret ? 1 : 0]\n"; # avoid status 100 which grid-engine 398 | print Q "## submitted with:\n"; # treats specially. 399 | $qsub_cmd .= "-o $queue_logfile $qsub_opts $queue_array_opt $queue_scriptfile >>$queue_logfile 2>&1"; 400 | print Q "# $qsub_cmd\n"; 401 | if (!close(Q)) { # close was not successful... || die "Could not close script file $shfile"; 402 | die "Failed to close the script file (full disk?)"; 403 | } 404 | chmod 0755, $queue_scriptfile; 405 | 406 | # This block submits the job to the queue.
407 | for (my $try = 1; $try < 5; $try++) { 408 | my $ret = system ($qsub_cmd); 409 | if ($ret != 0) { 410 | if ($sync && $ret == 256) { # this is the exit status when a job failed (bad exit status) 411 | if (defined $jobname) { 412 | $logfile =~ s/\$SGE_TASK_ID/*/g; 413 | } 414 | print STDERR "queue.pl: job writing to $logfile failed\n"; 415 | exit(1); 416 | } else { 417 | print STDERR "queue.pl: Error submitting jobs to queue (return status was $ret)\n"; 418 | print STDERR "queue log file is $queue_logfile, command was $qsub_cmd\n"; 419 | my $err = `tail $queue_logfile`; 420 | print STDERR "Output of qsub was: $err\n"; 421 | if ($err =~ m/gdi request/ || $err =~ m/qmaster/) { 422 | # When we get queue connectivity problems we usually see a message like: 423 | # Unable to run job: failed receiving gdi request response for mid=1 (got 424 | # syncron message receive timeout error).. 425 | my $waitfor = 20; 426 | print STDERR "queue.pl: It looks like the queue master may be inaccessible. " . 427 | " Trying again after $waitfor seconds\n"; 428 | sleep($waitfor); 429 | # ... and continue through the loop. 430 | } else { 431 | exit(1); 432 | } 433 | } 434 | } else { 435 | last; # break from the loop. 436 | } 437 | } 438 | 439 | if (! $sync) { # We're not submitting with -sync y, so we 440 | # need to wait for the jobs to finish. We wait for the 441 | # sync-files we "touched" in the script to exist. 442 | my @syncfiles = (); 443 | if (!defined $jobname) { # not an array job. 444 | push @syncfiles, $syncfile; 445 | } else { 446 | for (my $jobid = $jobstart; $jobid <= $jobend; $jobid++) { 447 | push @syncfiles, "$syncfile.$jobid"; 448 | } 449 | } 450 | # We will need the sge_job_id, to check that job still exists 451 | { # This block extracts the numeric SGE job-id from the log file in q/. 452 | # It may be used later to query 'qstat' about the job. 453 | open(L, "<$queue_logfile") || die "Error opening log file $queue_logfile"; 454 | undef $sge_job_id; 455 | while (<L>) { 456 | if (m/Your job\S* (\d+)[. ].+ has been submitted/) { 457 | if (defined $sge_job_id) { 458 | die "Error: your job was submitted more than once (see $queue_logfile)"; 459 | } else { 460 | $sge_job_id = $1; 461 | } 462 | } 463 | } 464 | close(L); 465 | if (!defined $sge_job_id) { 466 | die "Error: log file $queue_logfile does not specify the SGE job-id."; 467 | } 468 | } 469 | my $check_sge_job_ctr=1; 470 | 471 | my $wait = 0.1; 472 | my $counter = 0; 473 | foreach my $f (@syncfiles) { 474 | # wait for the jobs to finish one by one. 475 | while (! -f $f) { 476 | sleep($wait); 477 | $wait *= 1.2; 478 | if ($wait > 3.0) { 479 | $wait = 3.0; # never wait more than 3 seconds. 480 | # the following (.kick) commands are basically workarounds for NFS bugs. 481 | if (rand() < 0.25) { # don't do this every time... 482 | if (rand() > 0.5) { 483 | system("touch $qdir/sync/.kick"); 484 | } else { 485 | unlink("$qdir/sync/.kick"); 486 | } 487 | } 488 | if ($counter++ % 10 == 0) { 489 | # This seems to kick NFS in the teeth to cause it to refresh the 490 | # directory. I've seen cases where it would indefinitely fail to get 491 | # updated, even though the file exists on the server. 492 | # Only do this every 10 waits (every 30 seconds) though, or if there 493 | # are many jobs waiting they can overwhelm the file server. 494 | system("ls $qdir/sync >/dev/null"); 495 | } 496 | } 497 | 498 | # The purpose of the next block is so that queue.pl can exit if the job 499 | # was killed without terminating.
It's a bit complicated because (a) we 500 | # don't want to overload the qmaster by querying it too frequently), and 501 | # (b) sometimes the qmaster is unreachable or temporarily down, and we 502 | # don't want this to necessarily kill the job. 503 | if (($check_sge_job_ctr < 100 && ($check_sge_job_ctr++ % 10) == 0) || 504 | ($check_sge_job_ctr >= 100 && ($check_sge_job_ctr++ % 50) == 0)) { 505 | # Don't run qstat too often, avoid stress on SGE; the if-condition above 506 | # is designed to check every 10 waits at first, and eventually every 50 507 | # waits. 508 | if ( -f $f ) { next; } #syncfile appeared: OK. 509 | my $output = `qstat -j $sge_job_id 2>&1`; 510 | my $ret = $?; 511 | if ($ret >> 8 == 1 && $output !~ m/qmaster/ && 512 | $output !~ m/gdi request/) { 513 | # Don't consider immediately missing job as error, first wait some 514 | # time to make sure it is not just delayed creation of the syncfile. 515 | 516 | sleep(3); 517 | # Sometimes NFS gets confused and thinks it's transmitted the directory 518 | # but it hasn't, due to timestamp issues. Changing something in the 519 | # directory will usually fix that. 520 | system("touch $qdir/sync/.kick"); 521 | unlink("$qdir/sync/.kick"); 522 | if ( -f $f ) { next; } #syncfile appeared, ok 523 | sleep(7); 524 | system("touch $qdir/sync/.kick"); 525 | sleep(1); 526 | unlink("qdir/sync/.kick"); 527 | if ( -f $f ) { next; } #syncfile appeared, ok 528 | sleep(60); 529 | system("touch $qdir/sync/.kick"); 530 | sleep(1); 531 | unlink("$qdir/sync/.kick"); 532 | if ( -f $f ) { next; } #syncfile appeared, ok 533 | $f =~ m/\.(\d+)$/ || die "Bad sync-file name $f"; 534 | my $job_id = $1; 535 | if (defined $jobname) { 536 | $logfile =~ s/\$SGE_TASK_ID/$job_id/g; 537 | } 538 | my $last_line = `tail -n 1 $logfile`; 539 | if ($last_line =~ m/status 0$/ && (-M $logfile) < 0) { 540 | # if the last line of $logfile ended with "status 0" and 541 | # $logfile is newer than this program [(-M $logfile) gives the 542 | # time elapsed between file modification and the start of this 543 | # program], then we assume the program really finished OK, 544 | # and maybe something is up with the file system. 545 | print STDERR "**queue.pl: syncfile $f was not created but job seems\n" . 546 | "**to have finished OK. Probably your file-system has problems.\n" . 547 | "**This is just a warning.\n"; 548 | last; 549 | } else { 550 | chop $last_line; 551 | print STDERR "queue.pl: Error, unfinished job no " . 552 | "longer exists, log is in $logfile, last line is '$last_line', " . 553 | "syncfile is $f, return status of qstat was $ret\n" . 554 | "Possible reasons: a) Exceeded time limit? -> Use more jobs!" . 555 | " b) Shutdown/Frozen machine? -> Run again! Qmaster output " . 556 | "was: $output\n"; 557 | exit(1); 558 | } 559 | } elsif ($ret != 0) { 560 | print STDERR "queue.pl: Warning: qstat command returned status $ret (qstat -j $sge_job_id,$!)\n"; 561 | print STDERR "queue.pl: output was: $output"; 562 | } 563 | } 564 | } 565 | } 566 | unlink(@syncfiles); 567 | } 568 | 569 | # OK, at this point we are synced; we know the job is done. 570 | # But we don't know about its exit status. We'll look at $logfile for this. 571 | # First work out an array @logfiles of file-locations we need to 572 | # read (just one, unless it's an array job). 573 | my @logfiles = (); 574 | if (!defined $jobname) { # not an array job. 
575 | push @logfiles, $logfile; 576 | } else { 577 | for (my $jobid = $jobstart; $jobid <= $jobend; $jobid++) { 578 | my $l = $logfile; 579 | $l =~ s/\$SGE_TASK_ID/$jobid/g; 580 | push @logfiles, $l; 581 | } 582 | } 583 | 584 | my $num_failed = 0; 585 | my $status = 1; 586 | foreach my $l (@logfiles) { 587 | my @wait_times = (0.1, 0.2, 0.2, 0.3, 0.5, 0.5, 1.0, 2.0, 5.0, 5.0, 5.0, 10.0, 25.0); 588 | for (my $iter = 0; $iter <= @wait_times; $iter++) { 589 | my $line = `tail -10 $l 2>/dev/null`; # Note: although this line should be the last 590 | # line of the file, I've seen cases where it was not quite the last line because 591 | # of delayed output by the process that was running, or processes it had called. 592 | # so tail -10 gives it a little leeway. 593 | if ($line =~ m/with status (\d+)/) { 594 | $status = $1; 595 | last; 596 | } else { 597 | if ($iter < @wait_times) { 598 | sleep($wait_times[$iter]); 599 | } else { 600 | if (! -f $l) { 601 | print STDERR "Log-file $l does not exist.\n"; 602 | } else { 603 | print STDERR "The last line of log-file $l does not seem to indicate the " 604 | . "return status as expected\n"; 605 | } 606 | exit(1); # Something went wrong with the queue, or the 607 | # machine it was running on, probably. 608 | } 609 | } 610 | } 611 | # OK, now we have $status, which is the return-status of 612 | # the command in the job. 613 | if ($status != 0) { $num_failed++; } 614 | } 615 | if ($num_failed == 0) { exit(0); } 616 | else { # we failed. 617 | if (@logfiles == 1) { 618 | if (defined $jobname) { $logfile =~ s/\$SGE_TASK_ID/$jobstart/g; } 619 | print STDERR "queue.pl: job failed with status $status, log is in $logfile\n"; 620 | if ($logfile =~ m/JOB/) { 621 | print STDERR "queue.pl: probably you forgot to put JOB=1:\$nj in your script.\n"; 622 | } 623 | } else { 624 | if (defined $jobname) { $logfile =~ s/\$SGE_TASK_ID/*/g; } 625 | my $numjobs = 1 + $jobend - $jobstart; 626 | print STDERR "queue.pl: $num_failed / $numjobs failed, log is in $logfile\n"; 627 | } 628 | exit(1); 629 | } 630 | --------------------------------------------------------------------------------
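As a quick illustration of the option handling described at the top of recipes/libricss/utils/queue-ackgpu.pl, assuming the default config that is embedded in the script (the log path and the echoed command below are hypothetical placeholders, not files from the repository):

    # hypothetical invocation: a 4-way task array, one GPU and 2G of memory per task
    utils/queue-ackgpu.pl --gpu 1 --mem 2G JOB=1:4 exp/log/run.JOB.log echo JOB
    #
    # with the built-in default config, the CLI options are translated to qsub flags as
    #   --gpu 1   ->  -l gpu=1 -q '*.q'            (from: option gpu=* -l gpu=$0 -q '*.q')
    #   --mem 2G  ->  -l mem_free=2G,ram_free=2G   (from: option mem=* -l mem_free=$0,ram_free=$0)
    #   JOB=1:4   ->  -t 1:4                       (task-array range)
    # (the relative order of the flags may vary, since they are collected from a hash);
    # the generated wrapper script then runs "source acquire-gpu" once per requested GPU
    # before executing the command, which is what distinguishes queue-ackgpu.pl from queue.pl.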